Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0a4fb0da
Commit
0a4fb0da
authored
Jun 19, 2019
by
chrislarson1
Browse files
Merge remote-tracking branch 'upstream/master' into convert-back-to-tf
merging in latest changes from upstream
parents
314bc6bb
3763f894
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
512 additions
and
68 deletions
+512
-68
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+40
-17
pytorch_pretrained_bert/modeling_transfo_xl_utilities.py
pytorch_pretrained_bert/modeling_transfo_xl_utilities.py
+2
-2
pytorch_pretrained_bert/tokenization.py
pytorch_pretrained_bert/tokenization.py
+18
-7
pytorch_pretrained_bert/tokenization_gpt2.py
pytorch_pretrained_bert/tokenization_gpt2.py
+23
-11
pytorch_pretrained_bert/tokenization_openai.py
pytorch_pretrained_bert/tokenization_openai.py
+14
-9
pytorch_pretrained_bert/tokenization_transfo_xl.py
pytorch_pretrained_bert/tokenization_transfo_xl.py
+13
-8
tests/modeling_gpt2_test.py
tests/modeling_gpt2_test.py
+139
-4
tests/modeling_openai_test.py
tests/modeling_openai_test.py
+98
-1
tests/modeling_test.py
tests/modeling_test.py
+165
-9
No files found.
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
0a4fb0da
...
@@ -25,9 +25,6 @@ import copy
...
@@ -25,9 +25,6 @@ import copy
import
json
import
json
import
math
import
math
import
logging
import
logging
import
tarfile
import
tempfile
import
shutil
import
collections
import
collections
import
sys
import
sys
from
io
import
open
from
io
import
open
...
@@ -888,8 +885,7 @@ class TransfoXLPreTrainedModel(nn.Module):
...
@@ -888,8 +885,7 @@ class TransfoXLPreTrainedModel(nn.Module):
pass
pass
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
state_dict
=
None
,
cache_dir
=
None
,
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
):
from_tf
=
False
,
*
inputs
,
**
kwargs
):
"""
"""
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Download and cache the pre-trained model file if needed.
...
@@ -897,19 +893,25 @@ class TransfoXLPreTrainedModel(nn.Module):
...
@@ -897,19 +893,25 @@ class TransfoXLPreTrainedModel(nn.Module):
Params:
Params:
pretrained_model_name_or_path: either:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
- a str with the name of a pre-trained model to load selected in the list of:
. `transfo-xl`
. `transfo-xl
-wt103
`
- a path or url to a pretrained model archive containing:
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `transfo_xl_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
- a path or url to a pretrained model archive containing:
- a path or url to a pretrained model archive containing:
. `
bert
_config.json` a configuration file for the model
. `
transfo_xl
_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
*inputs, **kwargs: additional input for the specific TransformerXL class
(ex: num_labels for BertForSequenceClassification)
"""
"""
state_dict
=
kwargs
.
get
(
'state_dict'
,
None
)
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
get
(
'cache_dir'
,
None
)
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
get
(
'from_tf'
,
False
)
kwargs
.
pop
(
'from_tf'
,
None
)
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
...
@@ -919,16 +921,37 @@ class TransfoXLPreTrainedModel(nn.Module):
...
@@ -919,16 +921,37 @@ class TransfoXLPreTrainedModel(nn.Module):
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
archive_file
)
)
return
None
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
except
EnvironmentError
:
logger
.
error
(
if
pretrained_model_name_or_path
in
PRETRAINED_CONFIG_ARCHIVE_MAP
:
"Model name '{}' was not found in model name list ({}). "
logger
.
error
(
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
"at this path or url."
.
format
(
config_file
))
pretrained_model_name_or_path
,
else
:
', '
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
logger
.
error
(
pretrained_model_name_or_path
,
"Model name '{}' was not found in model name list ({}). "
archive_file
,
config_file
))
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
config_file
)
)
return
None
return
None
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
...
...
pytorch_pretrained_bert/modeling_transfo_xl_utilities.py
View file @
0a4fb0da
...
@@ -114,10 +114,10 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
...
@@ -114,10 +114,10 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
logit
=
self
.
_compute_logit
(
hidden
,
self
.
out_layers
[
0
].
weight
,
logit
=
self
.
_compute_logit
(
hidden
,
self
.
out_layers
[
0
].
weight
,
self
.
out_layers
[
0
].
bias
,
self
.
out_projs
[
0
])
self
.
out_layers
[
0
].
bias
,
self
.
out_projs
[
0
])
if
target
is
not
None
:
if
target
is
not
None
:
out
put
=
-
F
.
log_softmax
(
logit
,
dim
=-
1
)
\
out
=
-
F
.
log_softmax
(
logit
,
dim
=-
1
)
\
.
gather
(
1
,
target
.
unsqueeze
(
1
)).
squeeze
(
1
)
.
gather
(
1
,
target
.
unsqueeze
(
1
)).
squeeze
(
1
)
else
:
else
:
out
put
=
F
.
log_softmax
(
logit
,
dim
=-
1
)
out
=
F
.
log_softmax
(
logit
,
dim
=-
1
)
else
:
else
:
# construct weights and biases
# construct weights and biases
weights
,
biases
=
[],
[]
weights
,
biases
=
[],
[]
...
...
pytorch_pretrained_bert/tokenization.py
View file @
0a4fb0da
...
@@ -34,6 +34,9 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
...
@@ -34,6 +34,9 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-multilingual-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt"
,
'bert-base-multilingual-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt"
,
'bert-base-multilingual-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt"
,
'bert-base-multilingual-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt"
,
'bert-base-chinese'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt"
,
'bert-base-chinese'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt"
,
'bert-base-german-cased'
:
"https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt"
,
'bert-large-uncased-whole-word-masking'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt"
,
'bert-large-cased-whole-word-masking'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt"
,
}
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP
=
{
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP
=
{
'bert-base-uncased'
:
512
,
'bert-base-uncased'
:
512
,
...
@@ -43,6 +46,9 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
...
@@ -43,6 +46,9 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'bert-base-multilingual-uncased'
:
512
,
'bert-base-multilingual-uncased'
:
512
,
'bert-base-multilingual-cased'
:
512
,
'bert-base-multilingual-cased'
:
512
,
'bert-base-chinese'
:
512
,
'bert-base-chinese'
:
512
,
'bert-base-german-cased'
:
512
,
'bert-large-uncased-whole-word-masking'
:
512
,
'bert-large-cased-whole-word-masking'
:
512
,
}
}
VOCAB_NAME
=
'vocab.txt'
VOCAB_NAME
=
'vocab.txt'
...
@@ -175,13 +181,18 @@ class BertTokenizer(object):
...
@@ -175,13 +181,18 @@ class BertTokenizer(object):
try
:
try
:
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
except
EnvironmentError
:
logger
.
error
(
if
pretrained_model_name_or_path
in
PRETRAINED_VOCAB_ARCHIVE_MAP
:
"Model name '{}' was not found in model name list ({}). "
logger
.
error
(
"We assumed '{}' was a path or url but couldn't find any file "
"Couldn't reach server at '{}' to download vocabulary."
.
format
(
"associated to this path or url."
.
format
(
vocab_file
))
pretrained_model_name_or_path
,
else
:
', '
.
join
(
PRETRAINED_VOCAB_ARCHIVE_MAP
.
keys
()),
logger
.
error
(
vocab_file
))
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
PRETRAINED_VOCAB_ARCHIVE_MAP
.
keys
()),
vocab_file
))
return
None
return
None
if
resolved_vocab_file
==
vocab_file
:
if
resolved_vocab_file
==
vocab_file
:
logger
.
info
(
"loading vocabulary file {}"
.
format
(
vocab_file
))
logger
.
info
(
"loading vocabulary file {}"
.
format
(
vocab_file
))
...
...
pytorch_pretrained_bert/tokenization_gpt2.py
View file @
0a4fb0da
...
@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__)
...
@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP
=
{
PRETRAINED_VOCAB_ARCHIVE_MAP
=
{
'gpt2'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json"
,
'gpt2'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json"
,
'gpt2-medium'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json"
,
}
}
PRETRAINED_MERGES_ARCHIVE_MAP
=
{
PRETRAINED_MERGES_ARCHIVE_MAP
=
{
'gpt2'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt"
,
'gpt2'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt"
,
'gpt2-medium'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt"
,
}
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP
=
{
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP
=
{
'gpt2'
:
1024
,
'gpt2'
:
1024
,
...
@@ -91,7 +93,7 @@ class GPT2Tokenizer(object):
...
@@ -91,7 +93,7 @@ class GPT2Tokenizer(object):
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
cache_dir
=
None
,
*
inputs
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
cache_dir
=
None
,
*
inputs
,
**
kwargs
):
"""
"""
Instantiate a
PreTrainedBertModel
from a pre-trained model file.
Instantiate a
GPT2Tokenizer
from a pre-trained model file.
Download and cache the pre-trained model file if needed.
Download and cache the pre-trained model file if needed.
"""
"""
if
pretrained_model_name_or_path
in
PRETRAINED_VOCAB_ARCHIVE_MAP
:
if
pretrained_model_name_or_path
in
PRETRAINED_VOCAB_ARCHIVE_MAP
:
...
@@ -111,14 +113,19 @@ class GPT2Tokenizer(object):
...
@@ -111,14 +113,19 @@ class GPT2Tokenizer(object):
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
resolved_merges_file
=
cached_path
(
merges_file
,
cache_dir
=
cache_dir
)
resolved_merges_file
=
cached_path
(
merges_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
except
EnvironmentError
:
logger
.
error
(
if
pretrained_model_name_or_path
in
PRETRAINED_VOCAB_ARCHIVE_MAP
:
"Model name '{}' was not found in model name list ({}). "
logger
.
error
(
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"Couldn't reach server at '{}' to download vocabulary."
.
format
(
"at this path or url."
.
format
(
vocab_file
))
pretrained_model_name_or_path
,
else
:
', '
.
join
(
PRETRAINED_VOCAB_ARCHIVE_MAP
.
keys
()),
logger
.
error
(
pretrained_model_name_or_path
,
"Model name '{}' was not found in model name list ({}). "
vocab_file
,
merges_file
))
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
PRETRAINED_VOCAB_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
vocab_file
,
merges_file
))
return
None
return
None
if
resolved_vocab_file
==
vocab_file
and
resolved_merges_file
==
merges_file
:
if
resolved_vocab_file
==
vocab_file
and
resolved_merges_file
==
merges_file
:
logger
.
info
(
"loading vocabulary file {}"
.
format
(
vocab_file
))
logger
.
info
(
"loading vocabulary file {}"
.
format
(
vocab_file
))
...
@@ -263,9 +270,14 @@ class GPT2Tokenizer(object):
...
@@ -263,9 +270,14 @@ class GPT2Tokenizer(object):
def
encode
(
self
,
text
):
def
encode
(
self
,
text
):
return
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
))
return
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
))
def
decode
(
self
,
tokens
):
def
decode
(
self
,
tokens
,
skip_special_tokens
=
False
,
clean_up_tokenization_spaces
=
True
):
text
=
''
.
join
(
[
self
.
decoder
[
token
]
for
token
in
tokens
]
)
text
=
''
.
join
(
self
.
convert_ids_to_tokens
(
tokens
,
skip_special_tokens
=
skip_special_
tokens
)
)
text
=
bytearray
([
self
.
byte_decoder
[
c
]
for
c
in
text
]).
decode
(
'utf-8'
,
errors
=
self
.
errors
)
text
=
bytearray
([
self
.
byte_decoder
[
c
]
for
c
in
text
]).
decode
(
'utf-8'
,
errors
=
self
.
errors
)
if
clean_up_tokenization_spaces
:
text
=
text
.
replace
(
'<unk>'
,
''
)
text
=
text
.
replace
(
' .'
,
'.'
).
replace
(
' ?'
,
'?'
).
replace
(
' !'
,
'!'
).
replace
(
' ,'
,
','
).
replace
(
" ' "
,
"'"
).
replace
(
" n't"
,
"n't"
).
replace
(
" 'm"
,
"'m"
).
replace
(
" do not"
,
" don't"
).
replace
(
" 's"
,
"'s"
).
replace
(
" 've"
,
"'ve"
).
replace
(
" 're"
,
"'re"
)
return
text
return
text
def
save_vocabulary
(
self
,
vocab_path
):
def
save_vocabulary
(
self
,
vocab_path
):
...
...
pytorch_pretrained_bert/tokenization_openai.py
View file @
0a4fb0da
...
@@ -101,14 +101,19 @@ class OpenAIGPTTokenizer(object):
...
@@ -101,14 +101,19 @@ class OpenAIGPTTokenizer(object):
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
resolved_merges_file
=
cached_path
(
merges_file
,
cache_dir
=
cache_dir
)
resolved_merges_file
=
cached_path
(
merges_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
except
EnvironmentError
:
logger
.
error
(
if
pretrained_model_name_or_path
in
PRETRAINED_VOCAB_ARCHIVE_MAP
:
"Model name '{}' was not found in model name list ({}). "
logger
.
error
(
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"Couldn't reach server at '{}' to download vocabulary."
.
format
(
"at this path or url."
.
format
(
vocab_file
))
pretrained_model_name_or_path
,
else
:
', '
.
join
(
PRETRAINED_VOCAB_ARCHIVE_MAP
.
keys
()),
logger
.
error
(
pretrained_model_name_or_path
,
"Model name '{}' was not found in model name list ({}). "
vocab_file
,
merges_file
))
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
PRETRAINED_VOCAB_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
vocab_file
,
merges_file
))
return
None
return
None
if
resolved_vocab_file
==
vocab_file
and
resolved_merges_file
==
merges_file
:
if
resolved_vocab_file
==
vocab_file
and
resolved_merges_file
==
merges_file
:
logger
.
info
(
"loading vocabulary file {}"
.
format
(
vocab_file
))
logger
.
info
(
"loading vocabulary file {}"
.
format
(
vocab_file
))
...
@@ -272,7 +277,7 @@ class OpenAIGPTTokenizer(object):
...
@@ -272,7 +277,7 @@ class OpenAIGPTTokenizer(object):
out_string
=
''
.
join
(
tokens
).
replace
(
'</w>'
,
' '
).
strip
()
out_string
=
''
.
join
(
tokens
).
replace
(
'</w>'
,
' '
).
strip
()
if
clean_up_tokenization_spaces
:
if
clean_up_tokenization_spaces
:
out_string
=
out_string
.
replace
(
'<unk>'
,
''
)
out_string
=
out_string
.
replace
(
'<unk>'
,
''
)
out_string
=
out_string
.
replace
(
' .'
,
'.'
).
replace
(
' ?'
,
'?'
).
replace
(
' !'
,
'!'
).
replace
(
' ,'
,
','
).
replace
(
' ,'
,
','
out_string
=
out_string
.
replace
(
' .'
,
'.'
).
replace
(
' ?'
,
'?'
).
replace
(
' !'
,
'!'
).
replace
(
' ,'
,
','
).
replace
(
" ' "
,
"'"
).
replace
(
" n't"
,
"n't"
).
replace
(
" 'm"
,
"'m"
).
replace
(
" do not"
,
" don't"
).
replace
(
" ' "
,
"'"
).
replace
(
" n't"
,
"n't"
).
replace
(
" 'm"
,
"'m"
).
replace
(
" do not"
,
" don't"
).
replace
(
" 's"
,
"'s"
).
replace
(
" 've"
,
"'ve"
).
replace
(
" 're"
,
"'re"
)
).
replace
(
" 's"
,
"'s"
).
replace
(
" 've"
,
"'ve"
).
replace
(
" 're"
,
"'re"
)
return
out_string
return
out_string
...
...
pytorch_pretrained_bert/tokenization_transfo_xl.py
View file @
0a4fb0da
...
@@ -71,14 +71,19 @@ class TransfoXLTokenizer(object):
...
@@ -71,14 +71,19 @@ class TransfoXLTokenizer(object):
try
:
try
:
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
except
EnvironmentError
:
logger
.
error
(
if
pretrained_model_name_or_path
in
PRETRAINED_VOCAB_ARCHIVE_MAP
:
"Model name '{}' was not found in model name list ({}). "
logger
.
error
(
"We assumed '{}' was a path or url but couldn't find files {} "
"Couldn't reach server at '{}' to download vocabulary."
.
format
(
"at this path or url."
.
format
(
vocab_file
))
pretrained_model_name_or_path
,
else
:
', '
.
join
(
PRETRAINED_VOCAB_ARCHIVE_MAP
.
keys
()),
logger
.
error
(
pretrained_model_name_or_path
,
"Model name '{}' was not found in model name list ({}). "
vocab_file
))
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
PRETRAINED_VOCAB_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
vocab_file
))
return
None
return
None
if
resolved_vocab_file
==
vocab_file
:
if
resolved_vocab_file
==
vocab_file
:
logger
.
info
(
"loading vocabulary file {}"
.
format
(
vocab_file
))
logger
.
info
(
"loading vocabulary file {}"
.
format
(
vocab_file
))
...
...
tests/modeling_gpt2_test.py
View file @
0a4fb0da
...
@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
use_token_type_ids
=
True
,
use_token_type_ids
=
True
,
use_labels
=
True
,
use_labels
=
True
,
vocab_size
=
99
,
vocab_size
=
99
,
n_special
=
1
,
n_positions
=
33
,
n_positions
=
33
,
n_embd
=
32
,
n_embd
=
32
,
n_layer
=
5
,
n_layer
=
5
,
...
@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
self
.
use_token_type_ids
=
use_token_type_ids
self
.
use_token_type_ids
=
use_token_type_ids
self
.
use_labels
=
use_labels
self
.
use_labels
=
use_labels
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
n_special
=
n_special
self
.
n_positions
=
n_positions
self
.
n_positions
=
n_positions
self
.
n_embd
=
n_embd
self
.
n_embd
=
n_embd
self
.
n_layer
=
n_layer
self
.
n_layer
=
n_layer
...
@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
self
.
scope
=
scope
self
.
scope
=
scope
def
prepare_config_and_inputs
(
self
):
def
prepare_config_and_inputs
(
self
):
input_ids
=
GPT2ModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
],
self
.
vocab_size
)
total_num_tokens
=
self
.
vocab_size
+
self
.
n_special
input_ids
=
GPT2ModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
],
total_num_tokens
)
position_ids
=
None
position_ids
=
None
if
self
.
use_position_ids
:
if
self
.
use_position_ids
:
...
@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
config
=
GPT2Config
(
config
=
GPT2Config
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
vocab_size_or_config_json_file
=
self
.
vocab_size
,
n_special
=
self
.
n_special
,
n_positions
=
self
.
n_positions
,
n_positions
=
self
.
n_positions
,
n_embd
=
self
.
n_embd
,
n_embd
=
self
.
n_embd
,
n_layer
=
self
.
n_layer
,
n_layer
=
self
.
n_layer
,
...
@@ -111,8 +115,9 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -111,8 +115,9 @@ class GPT2ModelTest(unittest.TestCase):
return
outputs
return
outputs
def
check_gpt2_model_output
(
self
,
result
):
def
check_gpt2_model_output
(
self
,
result
):
self
.
parent
.
assertEqual
(
len
(
result
[
"hidden_states"
]),
self
.
n_layer
+
1
)
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states"
].
size
()),
list
(
result
[
"hidden_states"
]
[
0
]
.
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
self
.
n_embd
])
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
self
.
n_embd
])
...
@@ -129,11 +134,29 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -129,11 +134,29 @@ class GPT2ModelTest(unittest.TestCase):
}
}
return
outputs
return
outputs
def
create_gpt2_lm_head_with_output_attention
(
self
,
config
,
input_ids
,
token_type_ids
,
position_ids
,
mc_labels
,
lm_labels
,
mc_token_ids
):
model
=
GPT2LMHeadModel
(
config
,
output_attentions
=
True
)
model
.
eval
()
loss
=
model
(
input_ids
,
position_ids
,
token_type_ids
,
lm_labels
)
attentions
,
lm_logits
,
presents
=
model
(
input_ids
,
position_ids
,
token_type_ids
)
outputs
=
{
"loss"
:
loss
,
"lm_logits"
:
lm_logits
,
"presents"
:
presents
,
"attentions"
:
attentions
,
}
return
outputs
def
check_gpt2_lm_head_output
(
self
,
result
):
def
check_gpt2_lm_head_output
(
self
,
result
):
total_voc
=
self
.
vocab_size
total_voc
=
self
.
n_special
+
self
.
vocab_size
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits"
].
size
()),
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
self
.
parent
.
assertEqual
(
self
.
n_layer
,
len
(
result
[
"presents"
]))
self
.
parent
.
assertListEqual
(
list
(
result
[
"presents"
][
0
].
size
()),
[
2
,
self
.
batch_size
*
self
.
n_choices
,
self
.
n_head
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
def
check_gpt2_lm_head_loss_output
(
self
,
result
):
def
check_gpt2_lm_head_loss_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
...
@@ -156,8 +179,25 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -156,8 +179,25 @@ class GPT2ModelTest(unittest.TestCase):
}
}
return
outputs
return
outputs
def
create_gpt2_double_heads_with_output_attention
(
self
,
config
,
input_ids
,
token_type_ids
,
position_ids
,
mc_labels
,
lm_labels
,
mc_token_ids
):
model
=
GPT2DoubleHeadsModel
(
config
,
output_attentions
=
True
)
model
.
eval
()
loss
=
model
(
input_ids
,
mc_token_ids
,
lm_labels
=
lm_labels
,
mc_labels
=
mc_labels
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
)
attentions
,
lm_logits
,
mc_logits
,
presents
=
model
(
input_ids
,
mc_token_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
)
outputs
=
{
"loss"
:
loss
,
"lm_logits"
:
lm_logits
,
"mc_logits"
:
mc_logits
,
"presents"
:
presents
,
"attentions"
:
attentions
,
}
return
outputs
def
check_gpt2_double_heads_output
(
self
,
result
):
def
check_gpt2_double_heads_output
(
self
,
result
):
total_voc
=
self
.
vocab_size
total_voc
=
self
.
n_special
+
self
.
vocab_size
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits"
].
size
()),
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
...
@@ -170,6 +210,98 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -170,6 +210,98 @@ class GPT2ModelTest(unittest.TestCase):
[
list
(
l
.
size
())
for
l
in
result
[
"loss"
]],
[
list
(
l
.
size
())
for
l
in
result
[
"loss"
]],
[[],
[]])
[[],
[]])
def
create_and_check_gpt2_for_headmasking
(
self
,
config
,
input_ids
,
token_type_ids
,
position_ids
,
mc_labels
,
lm_labels
,
mc_token_ids
):
for
model_class
in
(
GPT2Model
,
GPT2LMHeadModel
,
GPT2DoubleHeadsModel
):
model
=
model_class
(
config
=
config
,
keep_multihead_output
=
True
)
model
.
eval
()
head_mask
=
torch
.
zeros
(
self
.
n_layer
,
self
.
n_head
).
to
(
input_ids
.
device
)
head_mask
[
0
,
1
:
-
1
]
=
1.0
# Mask all but the first and last heads on the first layer
head_mask
[
-
1
,
1
:]
=
1.0
# Mask all but the first head on the last layer
if
isinstance
(
model
,
GPT2DoubleHeadsModel
):
output
=
model
(
input_ids
,
mc_token_ids
,
head_mask
=
head_mask
)
else
:
output
=
model
(
input_ids
,
head_mask
=
head_mask
)
if
isinstance
(
model
,
GPT2Model
):
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
elif
isinstance
(
output
,
(
list
,
tuple
)):
output
=
sum
(
t
.
sum
()
for
t
in
output
[:
-
1
])
output
=
output
.
sum
()
output
.
backward
()
multihead_outputs
=
(
model
if
isinstance
(
model
,
GPT2Model
)
else
model
.
transformer
).
get_multihead_outputs
()
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
0
].
size
()),
[
self
.
batch_size
*
self
.
n_choices
,
self
.
n_head
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
0
][:,
1
:(
self
.
n_head
-
1
),
:,
:].
nonzero
()),
0
)
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
0
][:,
0
,
:,
:].
nonzero
()),
self
.
batch_size
*
self
.
n_choices
*
self
.
seq_length
*
self
.
n_embd
//
self
.
n_head
)
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
0
][:,
self
.
n_head
-
1
,
:,
:].
nonzero
()),
self
.
batch_size
*
self
.
n_choices
*
self
.
seq_length
*
self
.
n_embd
//
self
.
n_head
)
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
1
].
size
()),
[
self
.
batch_size
*
self
.
n_choices
,
self
.
n_head
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
1
].
nonzero
()),
multihead_outputs
[
1
].
numel
())
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
-
1
].
size
()),
[
self
.
batch_size
*
self
.
n_choices
,
self
.
n_head
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
-
1
][:,
1
:,
:,
:].
nonzero
()),
0
)
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
-
1
][:,
0
,
:,
:].
nonzero
()),
self
.
batch_size
*
self
.
n_choices
*
self
.
seq_length
*
self
.
n_embd
//
self
.
n_head
)
def
create_and_check_gpt2_for_head_pruning
(
self
,
config
,
input_ids
,
token_type_ids
,
position_ids
,
mc_labels
,
lm_labels
,
mc_token_ids
):
for
model_class
in
(
GPT2Model
,
GPT2LMHeadModel
,
GPT2DoubleHeadsModel
):
model
=
model_class
(
config
=
config
,
keep_multihead_output
=
True
)
model
.
eval
()
transformer
=
model
if
isinstance
(
model
,
GPT2Model
)
else
model
.
transformer
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
n_head
)),
-
1
:
[
0
]}
transformer
.
prune_heads
(
heads_to_prune
)
if
isinstance
(
model
,
GPT2DoubleHeadsModel
):
output
=
model
(
input_ids
,
mc_token_ids
)
else
:
output
=
model
(
input_ids
)
if
isinstance
(
model
,
GPT2Model
):
output
=
sum
(
t
.
sum
()
for
t
in
output
[
0
])
elif
isinstance
(
output
,
(
list
,
tuple
)):
output
=
sum
(
t
.
sum
()
for
t
in
output
[:
-
1
])
output
=
output
.
sum
()
output
.
backward
()
multihead_outputs
=
transformer
.
get_multihead_outputs
()
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
0
].
size
()),
[
self
.
batch_size
*
self
.
n_choices
,
1
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
1
].
size
()),
[
self
.
batch_size
*
self
.
n_choices
,
self
.
n_head
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
-
1
].
size
()),
[
self
.
batch_size
*
self
.
n_choices
,
self
.
n_head
-
1
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
def
test_default
(
self
):
def
test_default
(
self
):
self
.
run_tester
(
GPT2ModelTest
.
GPT2ModelTester
(
self
))
self
.
run_tester
(
GPT2ModelTest
.
GPT2ModelTester
(
self
))
...
@@ -208,6 +340,9 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -208,6 +340,9 @@ class GPT2ModelTest(unittest.TestCase):
tester
.
check_gpt2_double_heads_output
(
output_result
)
tester
.
check_gpt2_double_heads_output
(
output_result
)
tester
.
check_gpt2_double_heads_loss_output
(
output_result
)
tester
.
check_gpt2_double_heads_loss_output
(
output_result
)
tester
.
create_and_check_gpt2_for_headmasking
(
*
config_and_inputs
)
tester
.
create_and_check_gpt2_for_head_pruning
(
*
config_and_inputs
)
@
classmethod
@
classmethod
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
"""Creates a random int32 tensor of the shape within the vocab size."""
"""Creates a random int32 tensor of the shape within the vocab size."""
...
...
tests/modeling_openai_test.py
View file @
0a4fb0da
...
@@ -125,8 +125,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -125,8 +125,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
return
outputs
return
outputs
def
check_openai_model_output
(
self
,
result
):
def
check_openai_model_output
(
self
,
result
):
self
.
parent
.
assertEqual
(
len
(
result
[
"hidden_states"
]),
self
.
n_layer
+
1
)
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states"
].
size
()),
list
(
result
[
"hidden_states"
]
[
0
]
.
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
self
.
n_embd
])
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
self
.
n_embd
])
...
@@ -182,6 +183,99 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -182,6 +183,99 @@ class OpenAIGPTModelTest(unittest.TestCase):
[
list
(
l
.
size
())
for
l
in
result
[
"loss"
]],
[
list
(
l
.
size
())
for
l
in
result
[
"loss"
]],
[[],
[]])
[[],
[]])
def create_and_check_openai_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
                                            mc_labels, lm_labels, mc_token_ids):
    """Run each OpenAI GPT model with a head mask and inspect its kept multi-head outputs.

    Every model class is built with ``keep_multihead_output=True``, run once
    with ``head_mask``, reduced to a scalar and back-propagated. The tensors
    returned by ``get_multihead_outputs()`` are then checked for their shape
    and for which heads are entirely zero.

    ``token_type_ids``, ``position_ids``, ``mc_labels`` and ``lm_labels`` are
    accepted because the caller passes the whole ``config_and_inputs`` tuple;
    they are not used here.
    """
    expect_equal = self.parent.assertEqual
    expect_list_equal = self.parent.assertListEqual

    for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
        model = model_class(config=config, keep_multihead_output=True)
        model.eval()
        is_base_model = isinstance(model, OpenAIGPTModel)

        head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
        head_mask[0, 1:-1] = 1.0  # Mask all but the first and last heads on the first layer
        head_mask[-1, 1:] = 1.0  # Mask all but the first head on the last layer

        # Only the double-heads model receives mc_token_ids as a second positional input.
        if isinstance(model, OpenAIGPTDoubleHeadsModel):
            positional_inputs = (input_ids, mc_token_ids)
        else:
            positional_inputs = (input_ids,)
        result = model(*positional_inputs, head_mask=head_mask)

        # Collapse whatever the model returned into a single scalar, then back-propagate.
        if is_base_model:
            result = sum(t.sum() for t in result[0])
        elif isinstance(result, (list, tuple)):
            result = sum(t.sum() for t in result)
        result = result.sum()
        result.backward()

        # Head models expose the underlying transformer as ``model.transformer``.
        core = model if is_base_model else model.transformer
        per_layer = core.get_multihead_outputs()

        expect_equal(len(per_layer), self.n_layer)

        # First layer: only head 0 and the last head were left unmasked.
        first_layer = per_layer[0]
        full_shape = [self.batch_size * self.n_choices, self.n_head,
                      self.seq_length, self.n_embd // self.n_head]
        expect_list_equal(list(first_layer.size()), full_shape)
        expect_equal(len(first_layer[:, 1:(self.n_head - 1), :, :].nonzero()), 0)
        nonzero_per_head = self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head
        expect_equal(len(first_layer[:, 0, :, :].nonzero()), nonzero_per_head)
        expect_equal(len(first_layer[:, self.n_head - 1, :, :].nonzero()), nonzero_per_head)

        # Layer index 1: every element is expected to be non-zero.
        second_layer = per_layer[1]
        expect_list_equal(list(second_layer.size()), full_shape)
        expect_equal(len(second_layer.nonzero()), second_layer.numel())

        # Last layer: every head except head 0 was masked.
        last_layer = per_layer[-1]
        expect_list_equal(list(last_layer.size()), full_shape)
        expect_equal(len(last_layer[:, 1:, :, :].nonzero()), 0)
        expect_equal(len(last_layer[:, 0, :, :].nonzero()), nonzero_per_head)
def create_and_check_openai_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
                                             mc_labels, lm_labels, mc_token_ids):
    """Prune attention heads on each OpenAI GPT model and check the resulting head counts.

    Heads ``1 .. n_head - 1`` are pruned on layer 0 and head 0 is pruned on
    the last layer. After a forward and backward pass, the kept multi-head
    outputs must have one head on layer 0, ``n_head`` heads on layer 1 and
    ``n_head - 1`` heads on the last layer.

    ``token_type_ids``, ``position_ids``, ``mc_labels`` and ``lm_labels`` are
    accepted because the caller passes the whole ``config_and_inputs`` tuple;
    they are not used here.
    """
    def expected_shape(heads_kept):
        # Size of one layer's multi-head output when ``heads_kept`` heads remain.
        return [self.batch_size * self.n_choices, heads_kept,
                self.seq_length, self.n_embd // self.n_head]

    for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
        model = model_class(config=config, keep_multihead_output=True)
        model.eval()
        is_base_model = isinstance(model, OpenAIGPTModel)

        # Head models expose the underlying transformer as ``model.transformer``.
        transformer = model if is_base_model else model.transformer
        heads_to_prune = {0: list(range(1, self.n_head)),
                          -1: [0]}
        transformer.prune_heads(heads_to_prune)

        # Only the double-heads model receives mc_token_ids as a second positional input.
        if isinstance(model, OpenAIGPTDoubleHeadsModel):
            result = model(input_ids, mc_token_ids)
        else:
            result = model(input_ids)

        # Collapse whatever the model returned into a single scalar, then back-propagate.
        if is_base_model:
            result = sum(t.sum() for t in result[0])
        elif isinstance(result, (list, tuple)):
            result = sum(t.sum() for t in result)
        result = result.sum()
        result.backward()

        per_layer = transformer.get_multihead_outputs()

        self.parent.assertEqual(len(per_layer), self.n_layer)
        self.parent.assertListEqual(list(per_layer[0].size()), expected_shape(1))
        self.parent.assertListEqual(list(per_layer[1].size()), expected_shape(self.n_head))
        self.parent.assertListEqual(list(per_layer[-1].size()), expected_shape(self.n_head - 1))
def
test_default
(
self
):
def
test_default
(
self
):
self
.
run_tester
(
OpenAIGPTModelTest
.
OpenAIGPTModelTester
(
self
))
self
.
run_tester
(
OpenAIGPTModelTest
.
OpenAIGPTModelTester
(
self
))
...
@@ -220,6 +314,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -220,6 +314,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
tester
.
check_openai_double_heads_output
(
output_result
)
tester
.
check_openai_double_heads_output
(
output_result
)
tester
.
check_openai_double_heads_loss_output
(
output_result
)
tester
.
check_openai_double_heads_loss_output
(
output_result
)
tester
.
create_and_check_openai_for_headmasking
(
*
config_and_inputs
)
tester
.
create_and_check_openai_for_head_pruning
(
*
config_and_inputs
)
@
classmethod
@
classmethod
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
"""Creates a random int32 tensor of the shape within the vocab size."""
"""Creates a random int32 tensor of the shape within the vocab size."""
...
...
tests/modeling_test.py
View file @
0a4fb0da
...
@@ -28,7 +28,7 @@ import torch
...
@@ -28,7 +28,7 @@ import torch
from
pytorch_pretrained_bert
import
(
BertConfig
,
BertModel
,
BertForMaskedLM
,
from
pytorch_pretrained_bert
import
(
BertConfig
,
BertModel
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
BertForPreTraining
,
BertForNextSentencePrediction
,
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
,
BertForQuestionAnswering
,
BertForSequenceClassification
,
BertForTokenClassification
)
BertForTokenClassification
,
BertForMultipleChoice
)
from
pytorch_pretrained_bert.modeling
import
PRETRAINED_MODEL_ARCHIVE_MAP
from
pytorch_pretrained_bert.modeling
import
PRETRAINED_MODEL_ARCHIVE_MAP
...
@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
type_sequence_label_size
=
2
,
type_sequence_label_size
=
2
,
initializer_range
=
0.02
,
initializer_range
=
0.02
,
num_labels
=
3
,
num_labels
=
3
,
num_choices
=
4
,
scope
=
None
):
scope
=
None
):
self
.
parent
=
parent
self
.
parent
=
parent
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
...
@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
self
.
type_sequence_label_size
=
type_sequence_label_size
self
.
type_sequence_label_size
=
type_sequence_label_size
self
.
initializer_range
=
initializer_range
self
.
initializer_range
=
initializer_range
self
.
num_labels
=
num_labels
self
.
num_labels
=
num_labels
self
.
num_choices
=
num_choices
self
.
scope
=
scope
self
.
scope
=
scope
def
prepare_config_and_inputs
(
self
):
def
prepare_config_and_inputs
(
self
):
...
@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
...
@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
sequence_labels
=
None
sequence_labels
=
None
token_labels
=
None
token_labels
=
None
choice_labels
=
None
if
self
.
use_labels
:
if
self
.
use_labels
:
sequence_labels
=
BertModelTest
.
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
sequence_labels
=
BertModelTest
.
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
token_labels
=
BertModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
token_labels
=
BertModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
BertModelTest
.
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
BertConfig
(
config
=
BertConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
vocab_size_or_config_json_file
=
self
.
vocab_size
,
...
@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
...
@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
type_vocab_size
=
self
.
type_vocab_size
,
type_vocab_size
=
self
.
type_vocab_size
,
initializer_range
=
self
.
initializer_range
)
initializer_range
=
self
.
initializer_range
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
check_loss_output
(
self
,
result
):
def
check_loss_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
list
(
result
[
"loss"
].
size
()),
[])
[])
def
create_bert_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
):
def
create_bert_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertModel
(
config
=
config
)
model
=
BertModel
(
config
=
config
)
model
.
eval
()
model
.
eval
()
all_encoder_layers
,
pooled_output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
all_encoder_layers
,
pooled_output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
...
@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
self
.
parent
.
assertListEqual
(
list
(
result
[
"pooled_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"pooled_output"
].
size
()),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_bert_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
):
def
create_bert_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForMaskedLM
(
config
=
config
)
model
=
BertForMaskedLM
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
...
@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
list
(
result
[
"prediction_scores"
].
size
()),
list
(
result
[
"prediction_scores"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
def
create_bert_for_next_sequence_prediction
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
):
def
create_bert_for_next_sequence_prediction
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForNextSentencePrediction
(
config
=
config
)
model
=
BertForNextSentencePrediction
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
...
@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
[
self
.
batch_size
,
2
])
[
self
.
batch_size
,
2
])
def
create_bert_for_pretraining
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
):
def
create_bert_for_pretraining
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForPreTraining
(
config
=
config
)
model
=
BertForPreTraining
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
,
sequence_labels
)
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
,
sequence_labels
)
...
@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
[
self
.
batch_size
,
2
])
[
self
.
batch_size
,
2
])
def
create_bert_for_question_answering
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
):
def
create_bert_for_question_answering
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForQuestionAnswering
(
config
=
config
)
model
=
BertForQuestionAnswering
(
config
=
config
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
sequence_labels
)
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
sequence_labels
)
...
@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
[
self
.
batch_size
,
self
.
seq_length
])
[
self
.
batch_size
,
self
.
seq_length
])
def
create_bert_for_sequence_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
):
def
create_bert_for_sequence_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForSequenceClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
=
BertForSequenceClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
)
...
@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
[
self
.
batch_size
,
self
.
num_labels
])
[
self
.
batch_size
,
self
.
num_labels
])
def
create_bert_for_token_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
):
def
create_bert_for_token_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
BertForTokenClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
=
BertForTokenClassification
(
config
=
config
,
num_labels
=
self
.
num_labels
)
model
.
eval
()
model
.
eval
()
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
loss
=
model
(
input_ids
,
token_type_ids
,
input_mask
,
token_labels
)
...
@@ -246,6 +250,150 @@ class BertModelTest(unittest.TestCase):
...
@@ -246,6 +250,150 @@ class BertModelTest(unittest.TestCase):
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
num_labels
])
def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask,
                                    sequence_labels, token_labels, choice_labels):
    """Build a ``BertForMultipleChoice`` model and return its loss and logits.

    Each input tensor gets a new axis at position 1 that is expanded to
    ``self.num_choices`` and made contiguous. The model is called once with
    ``choice_labels`` (stored under ``"loss"``) and once without labels
    (stored under ``"logits"``).

    ``sequence_labels`` and ``token_labels`` are accepted because the caller
    passes the whole ``config_and_inputs`` tuple; they are not used here.
    """
    model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
    model.eval()

    def add_choice_axis(tensor):
        # Insert an axis at dim 1, expand it to num_choices, then make the result contiguous.
        return tensor.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()

    mc_input_ids = add_choice_axis(input_ids)
    mc_token_type_ids = add_choice_axis(token_type_ids)
    mc_input_mask = add_choice_axis(input_mask)

    loss = model(mc_input_ids, mc_token_type_ids, mc_input_mask, choice_labels)
    logits = model(mc_input_ids, mc_token_type_ids, mc_input_mask)
    return {
        "loss": loss,
        "logits": logits,
    }
def check_bert_for_multiple_choice(self, result):
    """Assert that ``result["logits"]`` has size ``[batch_size, num_choices]``."""
    actual_size = list(result["logits"].size())
    expected_size = [self.batch_size, self.num_choices]
    self.parent.assertListEqual(actual_size, expected_size)
def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask,
                                         sequence_labels, token_labels, choice_labels):
    """Check the attentions returned by each BERT model built with ``output_attentions=True``.

    The first element of the model output is taken as the attentions. It
    must have one entry per hidden layer, and the first entry must have size
    ``[batch_size, num_attention_heads, seq_length, seq_length]``.

    The three label arguments are accepted because the caller passes the
    whole ``config_and_inputs`` tuple; they are not used here.
    """
    classes_with_num_labels = [BertForSequenceClassification, BertForTokenClassification]
    all_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
                   BertForPreTraining, BertForQuestionAnswering,
                   BertForSequenceClassification, BertForTokenClassification)

    for model_class in all_classes:
        # The two classification heads additionally take num_labels.
        if model_class in classes_with_num_labels:
            extra_kwargs = {"num_labels": self.num_labels}
        else:
            extra_kwargs = {}
        model = model_class(config=config, output_attentions=True, **extra_kwargs)
        model.eval()

        attentions = model(input_ids, token_type_ids, input_mask)[0]

        self.parent.assertEqual(len(attentions), self.num_hidden_layers)
        expected_size = [self.batch_size, self.num_attention_heads,
                         self.seq_length, self.seq_length]
        self.parent.assertListEqual(list(attentions[0].size()), expected_size)
def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask,
                                          sequence_labels, token_labels, choice_labels):
    """Run each BERT model with a head mask and inspect its kept multi-head outputs.

    Every model class is built with ``keep_multihead_output=True``, run once
    with ``head_mask``, reduced to a scalar and back-propagated. The tensors
    returned by ``get_multihead_outputs()`` are then checked for their shape
    and for which heads are entirely zero.

    The three label arguments are accepted because the caller passes the
    whole ``config_and_inputs`` tuple; they are not used here.
    """
    expect_equal = self.parent.assertEqual
    expect_list_equal = self.parent.assertListEqual
    classes_with_num_labels = [BertForSequenceClassification, BertForTokenClassification]
    all_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
                   BertForPreTraining, BertForQuestionAnswering,
                   BertForSequenceClassification, BertForTokenClassification)

    for model_class in all_classes:
        # The two classification heads additionally take num_labels.
        if model_class in classes_with_num_labels:
            extra_kwargs = {"num_labels": self.num_labels}
        else:
            extra_kwargs = {}
        model = model_class(config=config, keep_multihead_output=True, **extra_kwargs)
        model.eval()
        is_base_model = isinstance(model, BertModel)

        head_mask = torch.zeros(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
        head_mask[0, 1:-1] = 1.0  # Mask all but the first and last heads on the first layer
        head_mask[-1, 1:] = 1.0  # Mask all but the first head on the last layer
        result = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)

        # Collapse whatever the model returned into a single scalar, then back-propagate.
        if is_base_model:
            result = sum(t.sum() for t in result[0])
        elif isinstance(result, (list, tuple)):
            result = sum(t.sum() for t in result)
        result = result.sum()
        result.backward()

        # Head models expose the underlying encoder as ``model.bert``.
        core = model if is_base_model else model.bert
        per_layer = core.get_multihead_outputs()

        expect_equal(len(per_layer), self.num_hidden_layers)

        # First layer: only head 0 and the last head were left unmasked.
        first_layer = per_layer[0]
        full_shape = [self.batch_size, self.num_attention_heads,
                      self.seq_length, self.hidden_size // self.num_attention_heads]
        expect_list_equal(list(first_layer.size()), full_shape)
        expect_equal(len(first_layer[:, 1:(self.num_attention_heads - 1), :, :].nonzero()), 0)
        nonzero_per_head = self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads
        expect_equal(len(first_layer[:, 0, :, :].nonzero()), nonzero_per_head)
        expect_equal(len(first_layer[:, self.num_attention_heads - 1, :, :].nonzero()), nonzero_per_head)

        # Layer index 1: every element is expected to be non-zero.
        second_layer = per_layer[1]
        expect_list_equal(list(second_layer.size()), full_shape)
        expect_equal(len(second_layer.nonzero()), second_layer.numel())

        # Last layer: every head except head 0 was masked.
        last_layer = per_layer[-1]
        expect_list_equal(list(last_layer.size()), full_shape)
        expect_equal(len(last_layer[:, 1:, :, :].nonzero()), 0)
        expect_equal(len(last_layer[:, 0, :, :].nonzero()), nonzero_per_head)
def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask,
                                           sequence_labels, token_labels, choice_labels):
    """Prune attention heads on each BERT model and check the resulting head counts.

    Heads ``1 .. num_attention_heads - 1`` are pruned on layer 0 and head 0
    is pruned on the last layer. After a forward and backward pass, the kept
    multi-head outputs must have one head on layer 0, all heads on layer 1
    and ``num_attention_heads - 1`` heads on the last layer.

    The three label arguments are accepted because the caller passes the
    whole ``config_and_inputs`` tuple; they are not used here.
    """
    def expected_shape(heads_kept):
        # Size of one layer's multi-head output when ``heads_kept`` heads remain.
        return [self.batch_size, heads_kept,
                self.seq_length, self.hidden_size // self.num_attention_heads]

    classes_with_num_labels = [BertForSequenceClassification, BertForTokenClassification]
    all_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
                   BertForPreTraining, BertForQuestionAnswering,
                   BertForSequenceClassification, BertForTokenClassification)

    for model_class in all_classes:
        # The two classification heads additionally take num_labels.
        if model_class in classes_with_num_labels:
            extra_kwargs = {"num_labels": self.num_labels}
        else:
            extra_kwargs = {}
        model = model_class(config=config, keep_multihead_output=True, **extra_kwargs)
        model.eval()
        is_base_model = isinstance(model, BertModel)

        # Head models expose the underlying encoder as ``model.bert``.
        bert_model = model if is_base_model else model.bert
        heads_to_prune = {0: list(range(1, self.num_attention_heads)),
                          -1: [0]}
        bert_model.prune_heads(heads_to_prune)

        result = model(input_ids, token_type_ids, input_mask)

        # Collapse whatever the model returned into a single scalar, then back-propagate.
        if is_base_model:
            result = sum(t.sum() for t in result[0])
        elif isinstance(result, (list, tuple)):
            result = sum(t.sum() for t in result)
        result = result.sum()
        result.backward()

        per_layer = bert_model.get_multihead_outputs()

        self.parent.assertEqual(len(per_layer), self.num_hidden_layers)
        self.parent.assertListEqual(list(per_layer[0].size()), expected_shape(1))
        self.parent.assertListEqual(list(per_layer[1].size()),
                                    expected_shape(self.num_attention_heads))
        self.parent.assertListEqual(list(per_layer[-1].size()),
                                    expected_shape(self.num_attention_heads - 1))
def
test_default
(
self
):
def
test_default
(
self
):
self
.
run_tester
(
BertModelTest
.
BertModelTester
(
self
))
self
.
run_tester
(
BertModelTest
.
BertModelTester
(
self
))
...
@@ -300,6 +448,14 @@ class BertModelTest(unittest.TestCase):
...
@@ -300,6 +448,14 @@ class BertModelTest(unittest.TestCase):
tester
.
check_bert_for_token_classification_output
(
output_result
)
tester
.
check_bert_for_token_classification_output
(
output_result
)
tester
.
check_loss_output
(
output_result
)
tester
.
check_loss_output
(
output_result
)
output_result
=
tester
.
create_bert_for_multiple_choice
(
*
config_and_inputs
)
tester
.
check_bert_for_multiple_choice
(
output_result
)
tester
.
check_loss_output
(
output_result
)
tester
.
create_and_check_bert_for_attentions
(
*
config_and_inputs
)
tester
.
create_and_check_bert_for_headmasking
(
*
config_and_inputs
)
tester
.
create_and_check_bert_for_head_pruning
(
*
config_and_inputs
)
@
classmethod
@
classmethod
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
"""Creates a random int32 tensor of the shape within the vocab size."""
"""Creates a random int32 tensor of the shape within the vocab size."""
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment