chenpangpang / transformers / Commits

Commit 0a4fb0da, authored Jun 19, 2019 by chrislarson1

Merge remote-tracking branch 'upstream/master' into convert-back-to-tf

merging in latest changes from upstream

Parents: 314bc6bb, 3763f894
Changes: 29. Showing 9 changed files with 512 additions and 68 deletions (+512, -68).
pytorch_pretrained_bert/modeling_transfo_xl.py            +40  -17
pytorch_pretrained_bert/modeling_transfo_xl_utilities.py   +2   -2
pytorch_pretrained_bert/tokenization.py                   +18   -7
pytorch_pretrained_bert/tokenization_gpt2.py              +23  -11
pytorch_pretrained_bert/tokenization_openai.py            +14   -9
pytorch_pretrained_bert/tokenization_transfo_xl.py        +13   -8
tests/modeling_gpt2_test.py                              +139   -4
tests/modeling_openai_test.py                             +98   -1
tests/modeling_test.py                                   +165   -9
pytorch_pretrained_bert/modeling_transfo_xl.py

@@ -25,9 +25,6 @@ import copy
 import json
 import math
 import logging
-import tarfile
-import tempfile
-import shutil
 import collections
 import sys
 from io import open

@@ -888,8 +885,7 @@ class TransfoXLPreTrainedModel(nn.Module):
         pass

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
-                        from_tf=False, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         """
         Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.

@@ -897,19 +893,25 @@ class TransfoXLPreTrainedModel(nn.Module):
         Params:
             pretrained_model_name_or_path: either:
                 - a str with the name of a pre-trained model to load selected in the list of:
-                    . `transfo-xl`
+                    . `transfo-xl-wt103`
                 - a path or url to a pretrained model archive containing:
                     . `transfo_xl_config.json` a configuration file for the model
                     . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
                 - a path or url to a pretrained model archive containing:
-                    . `bert_config.json` a configuration file for the model
+                    . `transfo_xl_config.json` a configuration file for the model
                     . `model.chkpt` a TensorFlow checkpoint
             from_tf: should we load the weights from a locally saved TensorFlow checkpoint
             cache_dir: an optional path to a folder in which the pre-trained models will be cached.
             state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
-            *inputs, **kwargs: additional input for the specific Bert class
-                (ex: num_labels for BertForSequenceClassification)
+            *inputs, **kwargs: additional input for the specific TransformerXL class
         """
+        state_dict = kwargs.get('state_dict', None)
+        kwargs.pop('state_dict', None)
+        cache_dir = kwargs.get('cache_dir', None)
+        kwargs.pop('cache_dir', None)
+        from_tf = kwargs.get('from_tf', False)
+        kwargs.pop('from_tf', None)
         if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
             archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
             config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]

@@ -919,16 +921,37 @@ class TransfoXLPreTrainedModel(nn.Module):
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
         except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                "at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
-                    pretrained_model_name_or_path,
-                    archive_file, config_file))
+            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download pretrained weights.".format(
+                        archive_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find file {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
+                        pretrained_model_name_or_path,
+                        archive_file))
+            return None
+        try:
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
+        except EnvironmentError:
+            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+                        config_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find file {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
+                        pretrained_model_name_or_path,
+                        config_file))
+            return None
         if resolved_archive_file == archive_file and resolved_config_file == config_file:
             logger.info("loading weights file {}".format(archive_file))
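With the new signature, state_dict, cache_dir and from_tf travel through **kwargs and are popped before being forwarded, so keyword-based call sites keep working unchanged. A minimal usage sketch (the cache path and checkpoint directory are illustrative, not part of this commit):

    from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLModel

    # Named model from PRETRAINED_MODEL_ARCHIVE_MAP; extra options now travel via **kwargs.
    model = TransfoXLModel.from_pretrained('transfo-xl-wt103', cache_dir='/tmp/transfo_xl_cache')

    # Loading a locally converted TensorFlow checkpoint (directory name is hypothetical).
    # model = TransfoXLModel.from_pretrained('./my_transfo_xl_dir', from_tf=True)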
pytorch_pretrained_bert/modeling_transfo_xl_utilities.py

@@ -114,10 +114,10 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
             logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                         self.out_layers[0].bias, self.out_projs[0])
             if target is not None:
-                output = -F.log_softmax(logit, dim=-1) \
+                out = -F.log_softmax(logit, dim=-1) \
                         .gather(1, target.unsqueeze(1)).squeeze(1)
             else:
-                output = F.log_softmax(logit, dim=-1)
+                out = F.log_softmax(logit, dim=-1)
         else:
             # construct weights and biases
             weights, biases = [], []
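The only substantive change here renames output to out; the gathered negative log-softmax itself is unchanged. For reference, that expression is the per-example negative log-likelihood, equivalent to F.nll_loss without reduction; a minimal standalone check with toy shapes (not part of this commit):

    import torch
    import torch.nn.functional as F

    logit = torch.randn(4, 10)            # [batch, vocab]
    target = torch.randint(0, 10, (4,))   # [batch]

    # The pattern used above: pick out -log p(target) per example.
    out = -F.log_softmax(logit, dim=-1).gather(1, target.unsqueeze(1)).squeeze(1)

    # Same quantity via the library helper.
    ref = F.nll_loss(F.log_softmax(logit, dim=-1), target, reduction='none')
    assert torch.allclose(out, ref)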
pytorch_pretrained_bert/tokenization.py

@@ -34,6 +34,9 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
     'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
+    'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
+    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
+    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
 }
 PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'bert-base-uncased': 512,

@@ -43,6 +46,9 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'bert-base-multilingual-uncased': 512,
     'bert-base-multilingual-cased': 512,
     'bert-base-chinese': 512,
+    'bert-base-german-cased': 512,
+    'bert-large-uncased-whole-word-masking': 512,
+    'bert-large-cased-whole-word-masking': 512,
 }
 VOCAB_NAME = 'vocab.txt'

@@ -175,13 +181,18 @@ class BertTokenizer(object):
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
         except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find any file "
-                "associated to this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    vocab_file))
+            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download vocabulary.".format(
+                        vocab_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find any file "
+                    "associated to this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        vocab_file))
             return None
         if resolved_vocab_file == vocab_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
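The three new map entries make the whole-word-masking and German BERT vocabularies loadable by name; a minimal sketch, assuming the S3 URLs above are reachable:

    from pytorch_pretrained_bert.tokenization import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking')
    tokens = tokenizer.tokenize("Hello, how are you?")
    ids = tokenizer.convert_tokens_to_ids(tokens)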
pytorch_pretrained_bert/tokenization_gpt2.py

@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__)

 PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
 }
 PRETRAINED_MERGES_ARCHIVE_MAP = {
     'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
 }
 PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'gpt2': 1024,

@@ -91,7 +93,7 @@ class GPT2Tokenizer(object):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
+        Instantiate a GPT2Tokenizer from a pre-trained model file.
         Download and cache the pre-trained model file if needed.
         """
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:

@@ -111,14 +113,19 @@ class GPT2Tokenizer(object):
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
             resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
         except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                "at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    pretrained_model_name_or_path,
-                    vocab_file, merges_file))
+            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download vocabulary.".format(
+                        vocab_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        pretrained_model_name_or_path,
+                        vocab_file, merges_file))
             return None
         if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
             logger.info("loading vocabulary file {}".format(vocab_file))

@@ -263,9 +270,14 @@ class GPT2Tokenizer(object):
     def encode(self, text):
         return self.convert_tokens_to_ids(self.tokenize(text))

-    def decode(self, tokens):
-        text = ''.join([self.decoder[token] for token in tokens])
+    def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+        text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        if clean_up_tokenization_spaces:
+            text = text.replace('<unk>', '')
+            text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
+                        ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
+                        ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
         return text

     def save_vocabulary(self, vocab_path):
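decode now accepts skip_special_tokens and clean_up_tokenization_spaces and goes through convert_ids_to_tokens instead of indexing self.decoder directly. A round-trip sketch, assuming the 'gpt2' vocabulary and merges files are reachable:

    from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    ids = tokenizer.encode("Hello, how are you?")
    text = tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)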
pytorch_pretrained_bert/tokenization_openai.py

@@ -101,14 +101,19 @@ class OpenAIGPTTokenizer(object):
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
             resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
         except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                "at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    pretrained_model_name_or_path,
-                    vocab_file, merges_file))
+            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download vocabulary.".format(
+                        vocab_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        pretrained_model_name_or_path,
+                        vocab_file, merges_file))
             return None
         if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
             logger.info("loading vocabulary file {}".format(vocab_file))

@@ -272,7 +277,7 @@ class OpenAIGPTTokenizer(object):
         out_string = ''.join(tokens).replace('</w>', ' ').strip()
         if clean_up_tokenization_spaces:
             out_string = out_string.replace('<unk>', '')
-            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ','
+            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                         ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                         ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
         return out_string
pytorch_pretrained_bert/tokenization_transfo_xl.py

@@ -71,14 +71,19 @@ class TransfoXLTokenizer(object):
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
         except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} "
-                "at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    pretrained_model_name_or_path,
-                    vocab_file))
+            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download vocabulary.".format(
+                        vocab_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find files {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        pretrained_model_name_or_path,
+                        vocab_file))
             return None
         if resolved_vocab_file == vocab_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
tests/modeling_gpt2_test.py

@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
                      use_token_type_ids=True,
                      use_labels=True,
                      vocab_size=99,
+                     n_special=1,
                      n_positions=33,
                      n_embd=32,
                      n_layer=5,

@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
             self.use_token_type_ids = use_token_type_ids
             self.use_labels = use_labels
             self.vocab_size = vocab_size
+            self.n_special = n_special
             self.n_positions = n_positions
             self.n_embd = n_embd
             self.n_layer = n_layer

@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
             self.scope = scope

         def prepare_config_and_inputs(self):
-            input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)
+            total_num_tokens = self.vocab_size + self.n_special
+            input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)

             position_ids = None
             if self.use_position_ids:

@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
             config = GPT2Config(
                 vocab_size_or_config_json_file=self.vocab_size,
+                n_special=self.n_special,
                 n_positions=self.n_positions,
                 n_embd=self.n_embd,
                 n_layer=self.n_layer,

@@ -111,8 +115,9 @@ class GPT2ModelTest(unittest.TestCase):
             return outputs

         def check_gpt2_model_output(self, result):
+            self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1)
             self.parent.assertListEqual(
-                list(result["hidden_states"].size()),
+                list(result["hidden_states"][0].size()),
                 [self.batch_size, self.n_choices, self.seq_length, self.n_embd])

@@ -129,11 +134,29 @@ class GPT2ModelTest(unittest.TestCase):
             }
             return outputs

+        def create_gpt2_lm_head_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
+                                                      mc_labels, lm_labels, mc_token_ids):
+            model = GPT2LMHeadModel(config, output_attentions=True)
+            model.eval()
+            loss = model(input_ids, position_ids, token_type_ids, lm_labels)
+            attentions, lm_logits, presents = model(input_ids, position_ids, token_type_ids)
+            outputs = {
+                "loss": loss,
+                "lm_logits": lm_logits,
+                "presents": presents,
+                "attentions": attentions,
+            }
+            return outputs
+
         def check_gpt2_lm_head_output(self, result):
-            total_voc = self.vocab_size
+            total_voc = self.n_special + self.vocab_size
             self.parent.assertListEqual(
                 list(result["lm_logits"].size()),
                 [self.batch_size, self.n_choices, self.seq_length, total_voc])
+            self.parent.assertEqual(self.n_layer, len(result["presents"]))
+            self.parent.assertListEqual(
+                list(result["presents"][0].size()),
+                [2, self.batch_size * self.n_choices, self.n_head,
+                 self.seq_length, self.n_embd // self.n_head])

         def check_gpt2_lm_head_loss_output(self, result):
             self.parent.assertListEqual(

@@ -156,8 +179,25 @@ class GPT2ModelTest(unittest.TestCase):
             }
             return outputs

+        def create_gpt2_double_heads_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
+                                                           mc_labels, lm_labels, mc_token_ids):
+            model = GPT2DoubleHeadsModel(config, output_attentions=True)
+            model.eval()
+            loss = model(input_ids, mc_token_ids,
+                         lm_labels=lm_labels, mc_labels=mc_labels,
+                         token_type_ids=token_type_ids, position_ids=position_ids)
+            attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+            outputs = {
+                "loss": loss,
+                "lm_logits": lm_logits,
+                "mc_logits": mc_logits,
+                "presents": presents,
+                "attentions": attentions,
+            }
+            return outputs
+
         def check_gpt2_double_heads_output(self, result):
-            total_voc = self.vocab_size
+            total_voc = self.n_special + self.vocab_size
             self.parent.assertListEqual(
                 list(result["lm_logits"].size()),
                 [self.batch_size, self.n_choices, self.seq_length, total_voc])

@@ -170,6 +210,98 @@ class GPT2ModelTest(unittest.TestCase):
                 [list(l.size()) for l in result["loss"]],
                 [[], []])

+        def create_and_check_gpt2_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
+                                                  mc_labels, lm_labels, mc_token_ids):
+            for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
+                head_mask[0, 1:-1] = 1.0  # Mask all but the first and last heads on the first layer
+                head_mask[-1, 1:] = 1.0  # Mask all but the first head on the last layer
+                if isinstance(model, GPT2DoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids, head_mask=head_mask)
+                else:
+                    output = model(input_ids, head_mask=head_mask)
+                if isinstance(model, GPT2Model):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output[:-1])
+                output = output.sum()
+                output.backward()
+                multihead_outputs = (model if isinstance(model, GPT2Model) else model.transformer).get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertEqual(
+                    len(multihead_outputs[1].nonzero()),
+                    multihead_outputs[1].numel())
+
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertEqual(
+                    len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[-1][:, 0, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+
+        def create_and_check_gpt2_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
+                                                   mc_labels, lm_labels, mc_token_ids):
+            for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                transformer = model if isinstance(model, GPT2Model) else model.transformer
+                heads_to_prune = {0: list(range(1, self.n_head)),
+                                  -1: [0]}
+                transformer.prune_heads(heads_to_prune)
+                if isinstance(model, GPT2DoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids)
+                else:
+                    output = model(input_ids)
+                if isinstance(model, GPT2Model):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output[:-1])
+                output = output.sum()
+                output.backward()
+                multihead_outputs = transformer.get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, 1,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size * self.n_choices, self.n_head - 1,
+                     self.seq_length, self.n_embd // self.n_head])
+
         def test_default(self):
             self.run_tester(GPT2ModelTest.GPT2ModelTester(self))

@@ -208,6 +340,9 @@ class GPT2ModelTest(unittest.TestCase):
             tester.check_gpt2_double_heads_output(output_result)
             tester.check_gpt2_double_heads_loss_output(output_result)

+            tester.create_and_check_gpt2_for_headmasking(*config_and_inputs)
+            tester.create_and_check_gpt2_for_head_pruning(*config_and_inputs)
+
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""
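Note the mask convention these new tests assume: a 1.0 entry in head_mask marks a head to be zeroed, which the assertions verify by checking that the masked slices of get_multihead_outputs() contain no nonzero values (and in the GPT-2 variant the presents entry is excluded from the backward sum via output[:-1]). A minimal sketch of building such a mask, with illustrative sizes:

    import torch

    n_layer, n_head = 5, 4  # illustrative; the tester above uses n_layer=5
    head_mask = torch.zeros(n_layer, n_head)
    head_mask[0, 1:-1] = 1.0  # zero out the middle heads of the first layer
    head_mask[-1, 1:] = 1.0   # zero out all but the first head of the last layer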
tests/modeling_openai_test.py

@@ -125,8 +125,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
             return outputs

         def check_openai_model_output(self, result):
+            self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1)
             self.parent.assertListEqual(
-                list(result["hidden_states"].size()),
+                list(result["hidden_states"][0].size()),
                 [self.batch_size, self.n_choices, self.seq_length, self.n_embd])

@@ -182,6 +183,99 @@ class OpenAIGPTModelTest(unittest.TestCase):
                 [list(l.size()) for l in result["loss"]],
                 [[], []])

+        def create_and_check_openai_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
+                                                    mc_labels, lm_labels, mc_token_ids):
+            for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
+                head_mask[0, 1:-1] = 1.0  # Mask all but the first and last heads on the first layer
+                head_mask[-1, 1:] = 1.0  # Mask all but the first head on the last layer
+                if isinstance(model, OpenAIGPTDoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids, head_mask=head_mask)
+                else:
+                    output = model(input_ids, head_mask=head_mask)
+                if isinstance(model, OpenAIGPTModel):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output)
+                output = output.sum()
+                output.backward()
+                multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertEqual(
+                    len(multihead_outputs[1].nonzero()),
+                    multihead_outputs[1].numel())
+
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertEqual(
+                    len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[-1][:, 0, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+
+        def create_and_check_openai_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
+                                                     mc_labels, lm_labels, mc_token_ids):
+            for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                transformer = model if isinstance(model, OpenAIGPTModel) else model.transformer
+                heads_to_prune = {0: list(range(1, self.n_head)),
+                                  -1: [0]}
+                transformer.prune_heads(heads_to_prune)
+                if isinstance(model, OpenAIGPTDoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids)
+                else:
+                    output = model(input_ids)
+                if isinstance(model, OpenAIGPTModel):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output)
+                output = output.sum()
+                output.backward()
+                multihead_outputs = transformer.get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, 1,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size * self.n_choices, self.n_head - 1,
+                     self.seq_length, self.n_embd // self.n_head])
+
         def test_default(self):
             self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self))

@@ -220,6 +314,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
             tester.check_openai_double_heads_output(output_result)
             tester.check_openai_double_heads_loss_output(output_result)

+            tester.create_and_check_openai_for_headmasking(*config_and_inputs)
+            tester.create_and_check_openai_for_head_pruning(*config_and_inputs)
+
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""
tests/modeling_test.py

@@ -28,7 +28,7 @@ import torch

 from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                      BertForNextSentencePrediction, BertForPreTraining,
                                      BertForQuestionAnswering, BertForSequenceClassification,
-                                     BertForTokenClassification)
+                                     BertForTokenClassification, BertForMultipleChoice)
 from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP

@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
                      type_sequence_label_size=2,
                      initializer_range=0.02,
                      num_labels=3,
+                     num_choices=4,
                      scope=None):
             self.parent = parent
             self.batch_size = batch_size

@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
             self.type_sequence_label_size = type_sequence_label_size
             self.initializer_range = initializer_range
             self.num_labels = num_labels
+            self.num_choices = num_choices
             self.scope = scope

         def prepare_config_and_inputs(self):

@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
             sequence_labels = None
             token_labels = None
+            choice_labels = None
             if self.use_labels:
                 sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                 token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+                choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)

             config = BertConfig(
                 vocab_size_or_config_json_file=self.vocab_size,

@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
                 type_vocab_size=self.type_vocab_size,
                 initializer_range=self.initializer_range)

-            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
+            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

         def check_loss_output(self, result):
             self.parent.assertListEqual(
                 list(result["loss"].size()),
                 [])

-        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertModel(config=config)
             model.eval()
             all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)

@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
             self.parent.assertListEqual(
                 list(result["pooled_output"].size()),
                 [self.batch_size, self.hidden_size])

-        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForMaskedLM(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels)

@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
                 list(result["prediction_scores"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])

-        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForNextSentencePrediction(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels)

@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, 2])

-        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForPreTraining(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)

@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, 2])

-        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForQuestionAnswering(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)

@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.seq_length])

-        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels)

@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.num_labels])

-        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForTokenClassification(config=config, num_labels=self.num_labels)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels)

@@ -246,6 +250,150 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.seq_length, self.num_labels])

+        def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
+            model.eval()
+            multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            loss = model(multiple_choice_inputs_ids,
+                         multiple_choice_token_type_ids,
+                         multiple_choice_input_mask,
+                         choice_labels)
+            logits = model(multiple_choice_inputs_ids,
+                           multiple_choice_token_type_ids,
+                           multiple_choice_input_mask)
+            outputs = {
+                "loss": loss,
+                "logits": logits,
+            }
+            return outputs
+
+        def check_bert_for_multiple_choice(self, result):
+            self.parent.assertListEqual(
+                list(result["logits"].size()),
+                [self.batch_size, self.num_choices])
+
+        def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
+                                BertForTokenClassification):
+                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
+                    model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
+                else:
+                    model = model_class(config=config, output_attentions=True)
+                model.eval()
+                output = model(input_ids, token_type_ids, input_mask)
+                attentions = output[0]
+                self.parent.assertEqual(len(attentions), self.num_hidden_layers)
+                self.parent.assertListEqual(
+                    list(attentions[0].size()),
+                    [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
+
+        def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
+                                BertForTokenClassification):
+                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
+                    model = model_class(config=config, num_labels=self.num_labels, keep_multihead_output=True)
+                else:
+                    model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                head_mask = torch.zeros(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
+                head_mask[0, 1:-1] = 1.0  # Mask all but the first and last heads on the first layer
+                head_mask[-1, 1:] = 1.0  # Mask all but the first head on the last layer
+                output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
+
+                if isinstance(model, BertModel):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output)
+                output = output.sum()
+                output.backward()
+                multihead_outputs = (model if isinstance(model, BertModel) else model.bert).get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size, self.num_attention_heads,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                    self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
+                    self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size, self.num_attention_heads,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+                self.parent.assertEqual(
+                    len(multihead_outputs[1].nonzero()),
+                    multihead_outputs[1].numel())
+
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size, self.num_attention_heads,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+                self.parent.assertEqual(
+                    len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[-1][:, 0, :, :].nonzero()),
+                    self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+
+        def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
+                                BertForTokenClassification):
+                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
+                    model = model_class(config=config, num_labels=self.num_labels, keep_multihead_output=True)
+                else:
+                    model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                bert_model = model if isinstance(model, BertModel) else model.bert
+                heads_to_prune = {0: list(range(1, self.num_attention_heads)),
+                                  -1: [0]}
+                bert_model.prune_heads(heads_to_prune)
+                output = model(input_ids, token_type_ids, input_mask)
+
+                if isinstance(model, BertModel):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output)
+                output = output.sum()
+                output.backward()
+                multihead_outputs = bert_model.get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size, 1,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size, self.num_attention_heads,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size, self.num_attention_heads - 1,
+                     self.seq_length, self.hidden_size // self.num_attention_heads])
+
         def test_default(self):
             self.run_tester(BertModelTest.BertModelTester(self))

@@ -300,6 +448,14 @@ class BertModelTest(unittest.TestCase):
             tester.check_bert_for_token_classification_output(output_result)
             tester.check_loss_output(output_result)

+            output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
+            tester.check_bert_for_multiple_choice(output_result)
+            tester.check_loss_output(output_result)
+
+            tester.create_and_check_bert_for_attentions(*config_and_inputs)
+            tester.create_and_check_bert_for_headmasking(*config_and_inputs)
+            tester.create_and_check_bert_for_head_pruning(*config_and_inputs)
+
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""
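The new multiple-choice test duplicates each [batch, seq] tensor across a choices axis so BertForMultipleChoice receives [batch, num_choices, seq_length] inputs. The same reshaping as a standalone sketch, with illustrative sizes:

    import torch

    batch_size, num_choices, seq_length = 2, 4, 7
    input_ids = torch.randint(0, 99, (batch_size, seq_length))

    # Same pattern as the test above: insert a choices axis and expand across it.
    multiple_choice_input_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
    assert list(multiple_choice_input_ids.size()) == [batch_size, num_choices, seq_length]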