Commit 1484d67d authored Jul 02, 2019 by thomwolf
[LARGE] updating all tests and API
parent 4f8b5f68
Changes 28
Showing 20 changed files with 1479 additions and 1347 deletions
pytorch_pretrained_bert/model_utils.py                              +29   -47
pytorch_pretrained_bert/modeling.py                                 +47   -65
pytorch_pretrained_bert/modeling_gpt2.py                            +90   -212
pytorch_pretrained_bert/modeling_openai.py                          +78   -82
pytorch_pretrained_bert/modeling_transfo_xl.py                      +147  -57
pytorch_pretrained_bert/modeling_xlm.py                             +95   -688
pytorch_pretrained_bert/modeling_xlnet.py                           +64   -76
pytorch_pretrained_bert/tests/__init__.py                           +0    -0
pytorch_pretrained_bert/tests/conftest.py                           +0    -0
pytorch_pretrained_bert/tests/fixtures/input.txt                    +0    -0
pytorch_pretrained_bert/tests/fixtures/sample_text.txt              +0    -0
pytorch_pretrained_bert/tests/fixtures/test_sentencepiece.model     +0    -0
pytorch_pretrained_bert/tests/model_tests_commons.py                +379  -0
pytorch_pretrained_bert/tests/model_utils_test.py                   +50   -0
pytorch_pretrained_bert/tests/modeling_gpt2_test.py                 +55   -0
pytorch_pretrained_bert/tests/modeling_openai_test.py               +55   -0
pytorch_pretrained_bert/tests/modeling_test.py                      +307  -0
pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py           +45   -70
pytorch_pretrained_bert/tests/modeling_xlnet_test.py                +38   -50
pytorch_pretrained_bert/tests/optimization_test.py                  +0    -0
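A minimal sketch of the API shape these changes move toward, before the per-file hunks: options such as num_labels, output_attentions and output_hidden_states now live on the config object (popped from **kwargs by the new PretrainedConfig.__init__ in model_utils.py), and every model is built from a config alone. The argument values below are illustrative, not taken from this commit.

from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification

# Options that used to be model constructor arguments are now plain config attributes.
config = BertConfig(vocab_size_or_config_json_file=30522,
                    num_labels=3,
                    output_attentions=True,
                    output_hidden_states=True)

# Models now take only the config; heads read num_labels etc. from it.
model = BertForSequenceClassification(config)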
pytorch_pretrained_bert/model_utils.py

@@ -41,6 +41,12 @@ class PretrainedConfig(object):
    """
    pretrained_config_archive_map = {}

+   def __init__(self, **kwargs):
+       self.finetuning_task = kwargs.pop('finetuning_task', None)
+       self.num_labels = kwargs.pop('num_labels', 2)
+       self.output_attentions = kwargs.pop('output_attentions', False)
+       self.output_hidden_states = kwargs.pop('output_hidden_states', False)
+
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """

@@ -114,6 +120,9 @@ class PretrainedConfig(object):
            text = reader.read()
        return cls.from_dict(json.loads(text))

+   def __eq__(self, other):
+       return self.__dict__ == other.__dict__
+
    def __repr__(self):
        return str(self.to_json_string())

@@ -133,12 +142,11 @@ class PretrainedConfig(object):
class PreTrainedModel(nn.Module):
-   """ An abstract class to handle weights initialization and
+   """ An abstract class to handle storing model config and
        a simple interface for dowloading and loading pretrained models.
    """
    config_class = PretrainedConfig
    pretrained_model_archive_map = {}
    pretrained_config_archive_map = {}
    load_tf_weights = lambda model, config, path: None
    base_model_prefix = ""

@@ -151,8 +159,16 @@ class PreTrainedModel(nn.Module):
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__))
        # Save config in model
        self.config = config

+   def prune_heads(self, heads_to_prune):
+       """ Prunes heads of the base model.
+           heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+       """
+       model_to_prune = getattr(self, self.base_model_prefix, self)  # get the base model if needed
+       model_to_prune._prune_heads(heads_to_prune)
+
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """

@@ -175,24 +191,22 @@ class PreTrainedModel(nn.Module):
            *inputs, **kwargs: additional input for the specific XLNet class
                (ex: num_labels for XLNetForSequenceClassification)
        """
-       state_dict = kwargs.get('state_dict', None)
-       kwargs.pop('state_dict', None)
-       cache_dir = kwargs.get('cache_dir', None)
-       kwargs.pop('cache_dir', None)
-       from_tf = kwargs.get('from_tf', False)
-       kwargs.pop('from_tf', None)
+       state_dict = kwargs.pop('state_dict', None)
+       cache_dir = kwargs.pop('cache_dir', None)
+       from_tf = kwargs.pop('from_tf', None)
+
+       # Load config
+       config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # Load model
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
        else:
            if from_tf:
                # Directly load from a TensorFlow checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
                config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
            else:
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)

@@ -210,47 +224,15 @@ class PreTrainedModel(nn.Module):
                        ', '.join(cls.pretrained_model_archive_map.keys()),
                        archive_file))
            return None
-       try:
-           resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
-       except EnvironmentError:
-           if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
-               logger.error(
-                   "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-                       config_file))
-           else:
-               logger.error(
-                   "Model name '{}' was not found in model name list ({}). "
-                   "We assumed '{}' was a path or url but couldn't find any file "
-                   "associated to this path or url.".format(
-                       pretrained_model_name_or_path,
-                       ', '.join(cls.pretrained_config_archive_map.keys()),
-                       config_file))
-           return None
-       if resolved_archive_file == archive_file and resolved_config_file == config_file:
+       if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
-           logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
-           logger.info("loading configuration file {} from cache at {}".format(
-               config_file, resolved_config_file))
-       # Load config
-       config = cls.config_class.from_json_file(resolved_config_file)
-       # Update config with kwargs if needed
-       to_remove = []
-       for key, value in kwargs.items():
-           if hasattr(config, key):
-               setattr(config, key, value)
-               to_remove.append(key)
-       for key in to_remove:
-           kwargs.pop(key, None)
-       logger.info("Model config {}".format(config))
+
        # Instantiate model.
-       model = cls(config, *inputs, **kwargs)
+       model = cls(config)
+
        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        if from_tf:

@@ -275,7 +257,7 @@ class PreTrainedModel(nn.Module):
                if child is not None:
                    load(child, prefix + name + '.')

-       # Be able to load base models as well as derived models (with heads)
+       # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ''
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
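A hedged usage sketch of the refactored loading path above: from_pretrained now pops state_dict/cache_dir/from_tf itself, builds the config through cls.config_class.from_pretrained, and instantiates the model with cls(config); prune_heads lives on the shared base class. Whether leftover keywords such as num_labels are actually applied by PretrainedConfig.from_pretrained is not visible in this file, so treat that part as an assumption.

from pytorch_pretrained_bert.modeling import BertForSequenceClassification

# Archive name is illustrative; extra keywords are assumed to be consumed by the config.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)

# prune_heads() on PreTrainedModel dispatches to the base model's _prune_heads();
# the dict maps layer index -> list of heads to prune in that layer.
model.prune_heads({0: [0, 1], 11: [2]})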
pytorch_pretrained_bert/modeling.py

@@ -155,7 +155,7 @@ class BertConfig(PretrainedConfig):
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
-                vocab_size_or_config_json_file,
+                vocab_size_or_config_json_file=30522,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,

@@ -167,7 +167,7 @@ class BertConfig(PretrainedConfig):
                 type_vocab_size=2,
                 initializer_range=0.02,
-                layer_norm_eps=1e-12,
-                finetuning_task=None):
+                layer_norm_eps=1e-12,
+                **kwargs):
        """Constructs BertConfig.

        Args:

@@ -192,8 +192,8 @@ class BertConfig(PretrainedConfig):
            initializer_range: The sttdev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by LayerNorm.
-           finetuning_task: name of the glue task on which the model was fine-tuned if any
        """
+       super(BertConfig, self).__init__(**kwargs)
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:

@@ -213,7 +213,6 @@ class BertConfig(PretrainedConfig):
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
-           self.finetuning_task = finetuning_task
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")

@@ -270,13 +269,13 @@ class BertEmbeddings(nn.Module):
class BertSelfAttention(nn.Module):
-   def __init__(self, config, output_attentions=False):
+   def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
-       self.output_attentions = output_attentions
+       self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)

@@ -344,10 +343,9 @@ class BertSelfOutput(nn.Module):
class BertAttention(nn.Module):
-   def __init__(self, config, output_attentions=False):
+   def __init__(self, config):
        super(BertAttention, self).__init__()
-       self.output_attentions = output_attentions
-       self.self = BertSelfAttention(config, output_attentions=output_attentions)
+       self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def prune_heads(self, heads):

@@ -404,10 +402,9 @@ class BertOutput(nn.Module):
class BertLayer(nn.Module):
-   def __init__(self, config, output_attentions=False):
+   def __init__(self, config):
        super(BertLayer, self).__init__()
-       self.output_attentions = output_attentions
-       self.attention = BertAttention(config, output_attentions=output_attentions)
+       self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

@@ -421,11 +418,11 @@ class BertLayer(nn.Module):
class BertEncoder(nn.Module):
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertEncoder, self).__init__()
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       layer = BertLayer(config, output_attentions=output_attentions)
+       self.output_attentions = config.output_attentions
+       self.output_hidden_states = config.output_hidden_states
+       layer = BertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, head_mask=None):

@@ -546,9 +543,6 @@ class BertPreTrainedModel(PreTrainedModel):
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

-   def __init__(self, *inputs, **kwargs):
-       super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
-
    def init_weights(self, module):
        """ Initialize the weights.
        """

@@ -612,19 +606,19 @@ class BertModel(BertPreTrainedModel):
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertModel, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states

        self.embeddings = BertEmbeddings(config)
-       self.encoder = BertEncoder(config, output_attentions=output_attentions,
-                                  output_hidden_states=output_hidden_states)
+       self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.apply(self.init_weights)

-   def prune_heads(self, heads_to_prune):
+   def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+           See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

@@ -730,14 +724,12 @@ class BertForPreTraining(BertPreTrainedModel):
    masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.bert = BertModel(config, output_attentions=output_attentions,
-                             output_hidden_states=output_hidden_states)
+
+       self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,

@@ -809,13 +801,12 @@ class BertForMaskedLM(BertPreTrainedModel):
    masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.bert = BertModel(config, output_attentions=output_attentions)
+
+       self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):

@@ -880,12 +871,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
    seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.bert = BertModel(config, output_attentions=output_attentions)
+
+       self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)
        self.apply(self.init_weights)

@@ -954,15 +943,13 @@ class BertForSequenceClassification(BertPreTrainedModel):
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.num_labels = num_labels
-       self.bert = BertModel(config, output_attentions=output_attentions)
+       self.num_labels = config.num_labels
+
+       self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-       self.classifier = nn.Linear(config.hidden_size, num_labels)
+       self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        self.apply(self.init_weights)

@@ -997,7 +984,6 @@ class BertForMultipleChoice(BertPreTrainedModel):
        `config`: a BertConfig class instance with the configuration to build a new model
        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
        `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
-       `num_choices`: the number of classes for the classifier. Default = 2.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]

@@ -1030,25 +1016,23 @@ class BertForMultipleChoice(BertPreTrainedModel):
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    num_choices = 2
-   model = BertForMultipleChoice(config, num_choices)
+   model = BertForMultipleChoice(config)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, num_choices=2, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForMultipleChoice, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.num_choices = num_choices
-       self.bert = BertModel(config, output_attentions=output_attentions)
+
+       self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
+       """ Input shapes should be [bsz, num choices, seq length] """
+       num_choices = input_ids.shape[1]
+
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None

@@ -1057,7 +1041,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
-       reshaped_logits = logits.view(-1, self.num_choices)
+       reshaped_logits = logits.view(-1, num_choices)

        outputs = [reshaped_logits] + outputs[2:]  # add hidden states and attention if they are here

@@ -1118,15 +1102,13 @@ class BertForTokenClassification(BertPreTrainedModel):
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForTokenClassification, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.num_labels = num_labels
-       self.bert = BertModel(config, output_attentions=output_attentions)
+       self.num_labels = config.num_labels
+
+       self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-       self.classifier = nn.Linear(config.hidden_size, num_labels)
+       self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.apply(self.init_weights)

@@ -1204,12 +1186,12 @@ class BertForQuestionAnswering(BertPreTrainedModel):
    start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.bert = BertModel(config, output_attentions=output_attentions)
-       self.qa_outputs = nn.Linear(config.hidden_size, 2)
+       self.num_labels = config.num_labels
+
+       self.bert = BertModel(config)
+       self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        self.apply(self.init_weights)
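A sketch of the config-only constructors for the BERT heads shown above; the tensor shapes follow the BertForMultipleChoice docstring, and the list-style return follows the new "outputs = [...] + outputs[2:]" pattern, so the unpacking below is indicative rather than exhaustive.

import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForMultipleChoice

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12,
                    intermediate_size=3072)
model = BertForMultipleChoice(config)  # num_choices is now read from input_ids.shape[1]

input_ids = torch.zeros((2, 3, 7), dtype=torch.long)   # [bsz, num_choices, seq_length]
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)

outputs = model(input_ids, token_type_ids, attention_mask)
reshaped_logits = outputs[0]                           # [bsz, num_choices]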
pytorch_pretrained_bert/modeling_gpt2.py

@@ -119,7 +119,8 @@ class GPT2Config(PretrainedConfig):
                 attn_pdrop=0.1,
                 layer_norm_epsilon=1e-5,
                 initializer_range=0.02,
-                predict_special_tokens=True):
+                predict_special_tokens=True,
+                **kwargs):
        """Constructs GPT2Config.

@@ -142,6 +143,8 @@ class GPT2Config(PretrainedConfig):
                initializing all weight matrices.
            predict_special_tokens: should we predict special tokens (when the model has a LM head)
        """
+       super(GPT2Config, self).__init__(**kwargs)
+
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:

@@ -174,8 +177,10 @@ class GPT2Config(PretrainedConfig):
class Attention(nn.Module):
-   def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
+       self.output_attentions = config.output_attentions
+
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0

@@ -184,10 +189,6 @@ class Attention(nn.Module):
        self.split_size = n_state
        self.scale = scale

-       self.output_attentions = output_attentions
-       self.keep_multihead_output = keep_multihead_output
-       self.multihead_output = None
-
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)

@@ -224,9 +225,10 @@ class Attention(nn.Module):
        if head_mask is not None:
            w = w * head_mask

+       outputs = [torch.matmul(w, v)]
        if self.output_attentions:
-           return w, torch.matmul(w, v)
-       return torch.matmul(w, v)
+           outputs.append(w)
+       return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()

@@ -253,19 +255,15 @@ class Attention(nn.Module):
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

-       a = self._attn(query, key, value, head_mask)
-       if self.keep_multihead_output:
-           self.multihead_output = a
-           self.multihead_output.retain_grad()
-       if self.output_attentions:
-           attentions, a = a
+       attn_outputs = self._attn(query, key, value, head_mask)
+       a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
-       if self.output_attentions:
-           return attentions, a, present
-       return a, present
+
+       outputs = [a, present] + attn_outputs[1:]
+       return outputs  # a, present, (attentions)


class MLP(nn.Module):

@@ -284,27 +282,24 @@ class MLP(nn.Module):
class Block(nn.Module):
-   def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
-       self.output_attentions = output_attentions
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
-       self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
+       self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None, head_mask=None):
        output_attn = self.attn(self.ln_1(x), layer_past=layer_past, head_mask=head_mask)
-       if self.output_attentions:
-           attentions, a, present = output_attn
-       else:
-           a, present = output_attn
+       a = output_attn[0]  # output_attn: a, present, (attentions)
+
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
-       if self.output_attentions:
-           return attentions, x, present
-       return x, present
+
+       outputs = [x] + output_attn[1:]
+       return outputs  # x, present, (attentions)


class GPT2LMHead(nn.Module):

@@ -342,12 +337,17 @@ class GPT2MultipleChoiceHead(nn.Module):
        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

-   def forward(self, hidden_states, mc_token_ids):
-       # Classification logits
-       # hidden_state (bsz, num_choices, seq_length, hidden_size)
-       # mc_token_ids (bsz, num_choices)
+   def forward(self, hidden_states, mc_token_ids=None):
+       """ Extract classification token hidden state and project it using self.linear
+           hidden_state: shape (bsz, num_choices, seq_length, hidden_size)
+           mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
+               if mc_token_ids=None we take the last token of the sequence as classification token
+       """
+       if mc_token_ids is None:
+           mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
+       else:
+           mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))  # (bsz, num_choices, 1, hidden_size)
+       # mc_token_ids has shape (bsz, num_choices, 1, hidden_size)
        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)  # (bsz, num_choices, hidden_size)
        multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)

@@ -362,13 +362,9 @@ class GPT2PreTrainedModel(PreTrainedModel):
    """
    config_class = GPT2Config
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

-   def __init__(self, *inputs, **kwargs):
-       super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
-
    def init_weights(self, module):
        """ Initialize the weights.
        """
@@ -403,126 +399,9 @@ class GPT2PreTrainedModel(PreTrainedModel):
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific GPT2 class
"""
# state_dict = kwargs.get('state_dict', None)
# kwargs.pop('state_dict', None)
# cache_dir = kwargs.get('cache_dir', None)
# kwargs.pop('cache_dir', None)
# from_tf = kwargs.get('from_tf', False)
# kwargs.pop('from_tf', None)
        num_special_tokens = kwargs.get('num_special_tokens', None)
        kwargs.pop('num_special_tokens', None)
# if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
# archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
# config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
# else:
# archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
# config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# # redirect to the cache, if necessary
# try:
# resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
# except EnvironmentError:
# if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
# logger.error(
# "Couldn't reach server at '{}' to download pretrained weights.".format(
# archive_file))
# else:
# logger.error(
# "Model name '{}' was not found in model name list ({}). "
# "We assumed '{}' was a path or url but couldn't find file {} "
# "at this path or url.".format(
# pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
# archive_file
# )
# )
# return None
# try:
# resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
# except EnvironmentError:
# if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
# logger.error(
# "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
# config_file))
# else:
# logger.error(
# "Model name '{}' was not found in model name list ({}). "
# "We assumed '{}' was a path or url but couldn't find file {} "
# "at this path or url.".format(
# pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
# config_file
# )
# )
# return None
# if resolved_archive_file == archive_file and resolved_config_file == config_file:
# logger.info("loading weights file {}".format(archive_file))
# logger.info("loading configuration file {}".format(config_file))
# else:
# logger.info("loading weights file {} from cache at {}".format(
# archive_file, resolved_archive_file))
# logger.info("loading configuration file {} from cache at {}".format(
# config_file, resolved_config_file))
# # Load config
# config = GPT2Config.from_json_file(resolved_config_file)
# logger.info("Model config {}".format(config))
# # Instantiate model.
# model = cls(config, *inputs, **kwargs)
# if state_dict is None and not from_tf:
# state_dict = torch.load(resolved_archive_file, map_location='cpu')
# if from_tf:
# # Directly load from a TensorFlow checkpoint (stored as NumPy array)
# return load_tf_weights_in_gpt2(model, resolved_archive_file)
# old_keys = []
# new_keys = []
# for key in state_dict.keys():
# new_key = None
# if key.endswith(".g"):
# new_key = key[:-2] + ".weight"
# elif key.endswith(".b"):
# new_key = key[:-2] + ".bias"
# elif key.endswith(".w"):
# new_key = key[:-2] + ".weight"
# if new_key:
# old_keys.append(key)
# new_keys.append(new_key)
# for old_key, new_key in zip(old_keys, new_keys):
# state_dict[new_key] = state_dict.pop(old_key)
# missing_keys = []
# unexpected_keys = []
# error_msgs = []
# # copy state_dict so _load_from_state_dict can modify it
# metadata = getattr(state_dict, "_metadata", None)
# state_dict = state_dict.copy()
# if metadata is not None:
# state_dict._metadata = metadata
# def load(module, prefix=""):
# local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
# module._load_from_state_dict(
# state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
# )
# for name, child in module._modules.items():
# if child is not None:
# load(child, prefix + name + ".")
# start_model = model
# if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
# start_model = model.transformer
# load(start_model, prefix="")
# if len(missing_keys) > 0:
# logger.info(
# "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
# )
# if len(unexpected_keys) > 0:
# logger.info(
# "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
# )
# if len(error_msgs) > 0:
# raise RuntimeError(
# "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
# )
        num_special_tokens = kwargs.pop('num_special_tokens', None)

        model = PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)
# Add additional embeddings for special tokens if needed
# This step also make sure we are still sharing the output and input embeddings after loading weights
@@ -553,8 +432,6 @@ class GPT2Model(GPT2PreTrainedModel):
    Params:
        `config`: a GPT2Config class instance with the configuration to build a new model
        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-       `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-           This can be used to compute head importance metrics. Default: False

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]

@@ -591,14 +468,15 @@ class GPT2Model(GPT2PreTrainedModel):
    ```
    """
-   def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, config):
        super(GPT2Model, self).__init__(config)
-       self.output_attentions = output_attentions
+       self.output_hidden_states = config.output_hidden_states
+       self.output_attentions = config.output_attentions
+
        self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
-       block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
-                     keep_multihead_output=keep_multihead_output)
+       block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

@@ -618,19 +496,13 @@ class GPT2Model(GPT2PreTrainedModel):
        # Copy word embeddings from the previous weights
        self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]

-   def prune_heads(self, heads_to_prune):
+   def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

-   def get_multihead_outputs(self):
-       """ Gather all multi-head outputs.
-           Return: list (layers) of multihead module outputs with gradients
-       """
-       return [h.attn.multihead_output for h in self.h]
-
    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
        if past is None:
            past_length = 0

@@ -675,20 +547,32 @@ class GPT2Model(GPT2PreTrainedModel):
        all_attentions = []
        all_hidden_states = []
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
+           if self.output_hidden_states:
+               all_hidden_states.append(hidden_states.view(*output_shape))
+
            outputs = block(hidden_states, layer_past, head_mask[i])
-           if self.output_attentions:
-               attentions, hidden_states, present = outputs
-               all_attentions.append(attentions)
-           else:
-               hidden_states, present = outputs
+           hidden_states, present = outputs[:2]
            presents.append(present)
+           if self.output_attentions:
+               all_attentions.append(outputs[2])
+
        hidden_states = self.ln_f(hidden_states)
-       all_hidden_states.append(hidden_states.view(*output_shape))
+
+       hidden_states = hidden_states.view(*output_shape)
+       # Add last hidden state
+       if self.output_hidden_states:
+           all_hidden_states.append(hidden_states)
+
+       outputs = [hidden_states, presents]
+       if self.output_hidden_states:
+           outputs.append(all_hidden_states)
        if self.output_attentions:
-           return all_attentions, all_hidden_states, presents
-       return all_hidden_states, presents
+           # let the number of heads free (-1) so we can extract attention even after head pruning
+           attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
+           all_attentions = list(t.view(*attention_output_shape) for t in all_attentions)
+           outputs.append(all_attentions)
+       return outputs  # last hidden state, presents, (all hidden_states), (attentions)


class GPT2LMHeadModel(GPT2PreTrainedModel):

@@ -740,10 +624,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
    ```
    """
-   def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__(config)
-       self.transformer = GPT2Model(config, output_attentions=output_attentions,
-                                    keep_multihead_output=keep_multihead_output)
+       self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.apply(self.init_weights)

@@ -756,14 +639,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
-       transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
-       if self.transformer.output_attentions:
-           all_attentions, hidden_states, presents = transformer_output
-       else:
-           hidden_states, presents = transformer_output
-       hidden_states = hidden_states[-1]
+       transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
+       hidden_states = transformer_outputs[0]
+
        lm_logits = self.lm_head(hidden_states)
+
+       outputs = [lm_logits] + transformer_outputs[1:]
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()

@@ -772,10 +653,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
-           return loss
-       if self.transformer.output_attentions:
-           return all_attentions, lm_logits, presents
-       return lm_logits, presents
+           outputs = [loss] + outputs
+
+       return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)


class GPT2DoubleHeadsModel(GPT2PreTrainedModel):

@@ -832,12 +712,12 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    ```
    """
-   def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, config):
        super(GPT2DoubleHeadsModel, self).__init__(config)
-       self.transformer = GPT2Model(config, output_attentions=output_attentions,
-                                    keep_multihead_output=keep_multihead_output)
+       self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.multiple_choice_head = GPT2MultipleChoiceHead(config)
        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):

@@ -848,28 +728,26 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
        self.transformer.set_num_special_tokens(num_special_tokens)
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

-   def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
+   def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
                position_ids=None, past=None, head_mask=None):
-       transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
-       if self.transformer.output_attentions:
-           all_attentions, hidden_states, presents = transformer_output
-       else:
-           hidden_states, presents = transformer_output
-       hidden_states = hidden_states[-1]
+       transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
+       hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
-       losses = []
+
+       outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+       if mc_labels is not None:
+           loss_fct = CrossEntropyLoss()
+           loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
+                           mc_labels.view(-1))
+           outputs = [loss] + outputs
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
-           losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
-       if mc_labels is not None:
-           loss_fct = CrossEntropyLoss()
-           losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
-       if losses:
-           return losses
-       if self.transformer.output_attentions:
-           return all_attentions, lm_logits, mc_logits, presents
-       return lm_logits, mc_logits, presents
+           loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                           shift_labels.view(-1))
+           outputs = [loss] + outputs
+
+       return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
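A hedged sketch of consuming the new GPT-2 return convention above: forward now always returns a list — (loss), lm_logits, presents, (all hidden_states), (attentions) — whose optional members are controlled by config flags rather than constructor arguments. GPT2Config is assumed to keep its usual architecture defaults, so only the new flags are passed here.

import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2LMHeadModel

config = GPT2Config(output_attentions=True, output_hidden_states=True)
model = GPT2LMHeadModel(config)

input_ids = torch.zeros((1, 5), dtype=torch.long)
outputs = model(input_ids)                  # no lm_labels, so no leading loss element
lm_logits, presents = outputs[0], outputs[1]
all_hidden_states = outputs[2]              # present because config.output_hidden_states is True
all_attentions = outputs[3]                 # present because config.output_attentions is True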
pytorch_pretrained_bert/modeling_openai.py

@@ -147,7 +147,8 @@ class OpenAIGPTConfig(PretrainedConfig):
                 attn_pdrop=0.1,
                 layer_norm_epsilon=1e-5,
                 initializer_range=0.02,
-                predict_special_tokens=True):
+                predict_special_tokens=True,
+                **kwargs):
        """Constructs OpenAIGPTConfig.

@@ -172,6 +173,8 @@ class OpenAIGPTConfig(PretrainedConfig):
                initializing all weight matrices.
            predict_special_tokens: should we predict special tokens (when the model has a LM head)
        """
+       super(OpenAIGPTConfig, self).__init__(**kwargs)
+
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:

@@ -205,7 +208,7 @@ class OpenAIGPTConfig(PretrainedConfig):
class Attention(nn.Module):
-   def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]

@@ -215,9 +218,7 @@ class Attention(nn.Module):
        self.split_size = n_state
        self.scale = scale
-       self.output_attentions = output_attentions
-       self.keep_multihead_output = keep_multihead_output
-       self.multihead_output = None
+       self.output_attentions = config.output_attentions

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

@@ -256,9 +257,10 @@ class Attention(nn.Module):
        if head_mask is not None:
            w = w * head_mask

+       outputs = [torch.matmul(w, v)]
        if self.output_attentions:
-           return w, torch.matmul(w, v)
-       return torch.matmul(w, v)
+           outputs.append(w)
+       return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()

@@ -280,19 +282,15 @@ class Attention(nn.Module):
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

-       a = self._attn(query, key, value, head_mask)
-       if self.keep_multihead_output:
-           self.multihead_output = a
-           self.multihead_output.retain_grad()
-       if self.output_attentions:
-           attentions, a = a
+       attn_outputs = self._attn(query, key, value, head_mask)
+       a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
-       if self.output_attentions:
-           return attentions, a
-       return a
+
+       outputs = [a] + attn_outputs[1:]
+       return outputs  # a, (attentions)


class MLP(nn.Module):

@@ -311,25 +309,24 @@ class MLP(nn.Module):
class Block(nn.Module):
-   def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
-       self.output_attentions = output_attentions
-       self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
+       self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)

    def forward(self, x, head_mask=None):
-       a = self.attn(x, head_mask=head_mask)
-       if self.output_attentions:
-           attentions, a = a
+       attn_outputs = self.attn(x, head_mask=head_mask)
+       a = attn_outputs[0]
+
        n = self.ln_1(x + a)
        m = self.mlp(n)
        h = self.ln_2(n + m)
-       if self.output_attentions:
-           return attentions, h
-       return h
+
+       outputs = [h] + attn_outputs[1:]
+       return outputs


class OpenAIGPTLMHead(nn.Module):

@@ -368,10 +365,15 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

-   def forward(self, hidden_states, mc_token_ids):
-       # Classification logits
-       # hidden_state (bsz, num_choices, seq_length, hidden_size)
-       # mc_token_ids (bsz, num_choices)
+   def forward(self, hidden_states, mc_token_ids=None):
+       """ Extract classification token hidden state and project it using self.linear
+           hidden_state: hidden state of shape (bsz, num_choices, seq_length, hidden_size)
+           mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
+               if mc_token_ids=None we take the last token of the sequence as classification token
+       """
+       if mc_token_ids is None:
+           mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
+       else:
+           mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))  # (bsz, num_choices, 1, hidden_size)
        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)

@@ -388,13 +390,9 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
    """
    config_class = OpenAIGPTConfig
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_openai_gpt
    base_model_prefix = "transformer"

-   def __init__(self, *inputs, **kwargs):
-       super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
-
    def init_weights(self, module):
        """ Initialize the weights.
        """

@@ -495,14 +493,15 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
    ```
    """
-   def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, config):
        super(OpenAIGPTModel, self).__init__(config)
-       self.output_attentions = output_attentions
+       self.output_attentions = config.output_attentions
+       self.output_hidden_states = config.output_hidden_states
+
        self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
        self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
-       block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
-                     keep_multihead_output=keep_multihead_output)
+       block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])

        self.apply(self.init_weights)

@@ -521,19 +520,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
        # Copy word embeddings from the previous weights
        self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]

-   def prune_heads(self, heads_to_prune):
+   def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

-   def get_multihead_outputs(self):
-       """ Gather all multi-head outputs.
-           Return: list (layers) of multihead module outputs with gradients
-       """
-       return [h.attn.multihead_output for h in self.h]
-
    def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
        if position_ids is None:
            # This was used when we had a single embedding matrice from position and token embeddings

@@ -574,19 +567,26 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
        output_shape = input_shape + (hidden_states.size(-1),)

        all_attentions = []
-       all_hidden_states = [hidden_states.view(*output_shape)]
+       all_hidden_states = []
        for i, block in enumerate(self.h):
+           if self.output_hidden_states:
+               all_hidden_states.append(hidden_states.view(*output_shape))
            outputs = block(hidden_states, head_mask[i])
+           hidden_states = outputs[0]
            if self.output_attentions:
-               attentions, hidden_states = outputs
-               all_attentions.append(attentions)
-           else:
-               hidden_states = outputs
+               all_attentions.append(outputs[1])
+
+       # Add last layer
+       if self.output_hidden_states:
+           all_hidden_states.append(hidden_states.view(*output_shape))
+
+       outputs = [hidden_states.view(*output_shape)]
+       if self.output_hidden_states:
+           outputs.append(all_hidden_states)
        if self.output_attentions:
-           return all_attentions, all_hidden_states
-       return all_hidden_states
+           outputs.append(all_attentions)
+       return outputs  # last hidden state, (all hidden states), (all attentions)


class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):

@@ -650,10 +650,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
    ```
    """
-   def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, config):
        super(OpenAIGPTLMHeadModel, self).__init__(config)
-       self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
-                                         keep_multihead_output=keep_multihead_output)
+       self.transformer = OpenAIGPTModel(config)
        self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
        self.apply(self.init_weights)

@@ -666,12 +665,11 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
        self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
-       hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
-       if self.transformer.output_attentions:
-           all_attentions, hidden_states = hidden_states
-       hidden_states = hidden_states[-1]
+       transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
+       hidden_states = transformer_outputs[0]
+
        lm_logits = self.lm_head(hidden_states)

+       outputs = [lm_logits] + transformer_outputs[1:]
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()

@@ -680,10 +678,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
-           return loss
-       if self.transformer.output_attentions:
-           return all_attentions, lm_logits
-       return lm_logits
+           outputs = [loss] + outputs
+
+       return outputs  # (loss), lm_logits, (all hidden states), (all attentions)


class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):

@@ -752,10 +749,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
    ```
    """
-   def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, config):
        super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
-       self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
-                                         keep_multihead_output=keep_multihead_output)
+       self.transformer = OpenAIGPTModel(config)
        self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
        self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
        self.apply(self.init_weights)

@@ -768,26 +764,26 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
        self.transformer.set_num_special_tokens(num_special_tokens)
        self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)

-   def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
+   def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
                position_ids=None, head_mask=None):
-       hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
-       if self.transformer.output_attentions:
-           all_attentions, hidden_states = hidden_states
-       hidden_states = hidden_states[-1]
+       transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
+       hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
-       losses = []
+
+       outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+       if mc_labels is not None:
+           loss_fct = CrossEntropyLoss()
+           loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
+                           mc_labels.view(-1))
+           outputs = [loss] + outputs
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
-           losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
-       if mc_labels is not None:
-           loss_fct = CrossEntropyLoss()
-           losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
-       if losses:
-           return losses
-       if self.transformer.output_attentions:
-           return all_attentions, lm_logits, mc_logits
-       return lm_logits, mc_logits
+           loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                           shift_labels.view(-1))
+           outputs = [loss] + outputs
+
+       return outputs  # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)
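A sketch of the corresponding OpenAI GPT change above, where mc_token_ids is now optional on the double-heads model and defaults to the last token of each choice sequence; OpenAIGPTConfig is assumed to keep its usual architecture defaults, and the output list follows (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions).

import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, OpenAIGPTDoubleHeadsModel

config = OpenAIGPTConfig()
model = OpenAIGPTDoubleHeadsModel(config)

input_ids = torch.zeros((1, 2, 6), dtype=torch.long)   # [bsz, num_choices, seq_length]
outputs = model(input_ids)   # mc_token_ids omitted: the last position is used as the classification token
lm_logits, mc_logits = outputs[0], outputs[1]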
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
1484d67d
...
...
@@ -209,7 +209,8 @@ class TransfoXLConfig(PretrainedConfig):
init
=
"normal"
,
init_range
=
0.01
,
proj_init_std
=
0.01
,
init_std
=
0.02
):
init_std
=
0.02
,
**
kwargs
):
"""Constructs TransfoXLConfig.
Args:
...
...
@@ -244,6 +245,8 @@ class TransfoXLConfig(PretrainedConfig):
proj_init_std: parameters initialized by N(0, init_std)
init_std: parameters initialized by N(0, init_std)
"""
super
(
TransfoXLConfig
,
self
).
__init__
(
**
kwargs
)
if
isinstance
(
vocab_size_or_config_json_file
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
vocab_size_or_config_json_file
,
unicode
)):
with
open
(
vocab_size_or_config_json_file
,
"r"
,
encoding
=
'utf-8'
)
as
reader
:
...
...
@@ -287,6 +290,7 @@ class TransfoXLConfig(PretrainedConfig):
"or the path to a pretrained model config file (str)"
)
class
PositionalEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
demb
):
super
(
PositionalEmbedding
,
self
).
__init__
()
...
...
@@ -306,6 +310,7 @@ class PositionalEmbedding(nn.Module):
return
pos_emb
[:,
None
,:]
class
PositionwiseFF
(
nn
.
Module
):
def
__init__
(
self
,
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
False
):
super
(
PositionwiseFF
,
self
).
__init__
()
...
...
@@ -341,11 +346,14 @@ class PositionwiseFF(nn.Module):
return
output
class
MultiHeadAttn
(
nn
.
Module
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
dropout
,
dropatt
=
0
,
pre_lnorm
=
False
,
r_r_bias
=
None
,
r_w_bias
=
None
):
pre_lnorm
=
False
,
r_r_bias
=
None
,
r_w_bias
=
None
,
output_attentions
=
False
):
super
(
MultiHeadAttn
,
self
).
__init__
()
self
.
output_attentions
=
output_attentions
self
.
n_head
=
n_head
self
.
d_model
=
d_model
self
.
d_head
=
d_head
...
...
@@ -371,7 +379,7 @@ class MultiHeadAttn(nn.Module):
self
.
r_r_bias
=
r_r_bias
self
.
r_w_bias
=
r_w_bias
def
forward
(
self
,
h
,
attn_mask
=
None
,
mems
=
None
):
def
forward
(
self
,
h
,
attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
##### multihead attention
# [hlen x bsz x n_head x d_head]
...
...
@@ -404,6 +412,10 @@ class MultiHeadAttn(nn.Module):
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

+       # Mask heads if we want to
+       if head_mask is not None:
+           attn_prob = attn_prob * head_mask

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
...
...
@@ -415,19 +427,23 @@ class MultiHeadAttn(nn.Module):
        if self.pre_lnorm:
            ##### residual connection
-           output = h + attn_out
+           outputs = [h + attn_out]
        else:
            ##### residual connection + layer normalization
-           output = self.layer_norm(h + attn_out)
+           outputs = [self.layer_norm(h + attn_out)]

-       return output
+       if self.output_attentions:
+           outputs.append(attn_prob)
+       return outputs

class RelMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                r_r_bias=None, r_w_bias=None):
+                r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(RelMultiHeadAttn, self).__init__()

+       self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
...
...
@@ -506,7 +522,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

-   def forward(self, w, r, attn_mask=None, mems=None):
+   def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
...
...
@@ -561,6 +577,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

+       # Mask heads if we want to
+       if head_mask is not None:
+           attn_prob = attn_prob * head_mask

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
...
...
@@ -574,18 +594,21 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
        if self.pre_lnorm:
            ##### residual connection
-           output = w + attn_out
+           outputs = [w + attn_out]
        else:
            ##### residual connection + layer normalization
-           output = self.layer_norm(w + attn_out)
+           outputs = [self.layer_norm(w + attn_out)]

-       return output
+       if self.output_attentions:
+           outputs.append(attn_prob)
+       return outputs

class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

-   def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
+   def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
        # r_emb: [klen, n_head, d_head], used for term B
        # r_w_bias: [n_head, d_head], used for term C
        # r_bias: [klen, n_head], used for term D
...
...
@@ -646,6 +669,9 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

+       if head_mask is not None:
+           attn_prob = attn_prob * head_mask

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
...
...
@@ -659,12 +685,17 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
        if self.pre_lnorm:
            ##### residual connection
-           output = w + attn_out
+           outputs = [w + attn_out]
        else:
            ##### residual connection + layer normalization
-           output = self.layer_norm(w + attn_out)
+           outputs = [self.layer_norm(w + attn_out)]

+       if self.output_attentions:
+           outputs.append(attn_prob)
+       return outputs
-       return output

class DecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
...
...
@@ -674,13 +705,15 @@ class DecoderLayer(nn.Module):
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

-   def forward(self, dec_inp, dec_attn_mask=None, mems=None):
-       output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, mems=mems)
-       output = self.pos_ff(output)
+   def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
+       attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
+                                    mems=mems, head_mask=head_mask)
+       ff_output = self.pos_ff(attn_outputs[0])

-       return output
+       outputs = [ff_output] + attn_outputs[1:]
+       return outputs

class RelLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...
...
@@ -692,14 +725,16 @@ class RelLearnableDecoderLayer(nn.Module):
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

-   def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
-       output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
-                              attn_mask=dec_attn_mask, mems=mems)
-       output = self.pos_ff(output)
+   def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
+       attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
+                                    attn_mask=dec_attn_mask,
+                                    mems=mems, head_mask=head_mask)
+       ff_output = self.pos_ff(attn_outputs[0])

-       return output
+       outputs = [ff_output] + attn_outputs[1:]
+       return outputs

class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...
@@ -711,14 +746,17 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self
.
pos_ff
=
PositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
))
def
forward
(
self
,
dec_inp
,
r
,
dec_attn_mask
=
None
,
mems
=
None
):
def
forward
(
self
,
dec_inp
,
r
,
dec_attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
output
=
self
.
dec_attn
(
dec_inp
,
r
,
attn_
output
s
=
self
.
dec_attn
(
dec_inp
,
r
,
attn_mask
=
dec_attn_mask
,
mems
=
mems
)
output
=
self
.
pos_ff
(
output
)
mems
=
mems
,
head_mask
=
head_mask
)
ff_output
=
self
.
pos_ff
(
attn_outputs
[
0
])
outputs
=
[
ff_output
]
+
attn_outputs
[
1
:]
return
outputs
return
output
class
AdaptiveEmbedding
(
nn
.
Module
):
...
...
@@ -791,13 +829,9 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
"""
config_class
=
TransfoXLConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map
=
PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights
=
load_tf_weights_in_transfo_xl
base_model_prefix
=
"transformer"
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
TransfoXLPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
def
_init_weight
(
self
,
weight
):
if
self
.
config
.
init
==
'uniform'
:
nn
.
init
.
uniform_
(
weight
,
-
self
.
config
.
init_range
,
self
.
config
.
init_range
)
...
...
@@ -894,6 +928,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
"""
def
__init__
(
self
,
config
):
super
(
TransfoXLModel
,
self
).
__init__
(
config
)
self
.
output_attentions
=
config
.
output_attentions
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
n_token
=
config
.
n_token
self
.
d_embed
=
config
.
d_embed
...
...
@@ -928,7 +965,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                       r_r_bias=None if config.untie_r else self.r_r_bias)
+                       r_r_bias=None if config.untie_r else self.r_r_bias,
+                       output_attentions=self.output_attentions)
                )
        elif config.attn_type == 1: # learnable embeddings
            for i in range(config.n_layer):
...
...
@@ -938,7 +976,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                       r_r_bias=None if config.untie_r else self.r_r_bias)
+                       r_r_bias=None if config.untie_r else self.r_r_bias,
+                       output_attentions=self.output_attentions)
                )
        elif config.attn_type in [2, 3]: # absolute embeddings
            for i in range(config.n_layer):
...
...
@@ -947,7 +986,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                       r_r_bias=None if config.untie_r else self.r_r_bias)
+                       r_r_bias=None if config.untie_r else self.r_r_bias,
+                       output_attentions=self.output_attentions)
                )

        self.same_length = config.same_length
...
...
@@ -965,17 +1005,21 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
        elif self.attn_type == 3: # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                    self.n_layer, self.max_klen, self.n_head, self.d_head))

        self.apply(self.init_weights)

+   def backward_compatible(self):
+       self.sample_softmax = -1

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

+   def _prune_heads(self, heads):
+       logger.info("Head pruning is not implemented for Transformer-XL model")
+       pass

    def init_mems(self, data):
        if self.mem_len > 0:
            mems = []
...
...
@@ -1012,9 +1056,24 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
        return new_mems

-   def _forward(self, dec_inp, mems=None):
+   def _forward(self, dec_inp, mems=None, head_mask=None):
        qlen, bsz = dec_inp.size()

+       # Prepare head mask if needed
+       # 1.0 in head_mask indicate we keep the head
+       # attention_probs has shape bsz x n_heads x N x N
+       # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
+       # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
+       if head_mask is not None:
+           if head_mask.dim() == 1:
+               head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
+               head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
+           elif head_mask.dim() == 2:
+               head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
+           head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
+       else:
+           head_mask = [None] * self.n_layer

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
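For reference, a short sketch of the two accepted `head_mask` shapes before the broadcasting step above; the layer and head counts are illustrative values, not taken from this diff.

import torch

n_layer, n_head = 12, 8                     # illustrative sizes
head_mask_1d = torch.ones(n_head)           # one mask reused for every layer
head_mask_1d[0] = 0.0                       # 0.0 masks a head, 1.0 keeps it
head_mask_2d = torch.ones(n_layer, n_head)  # or one row per layer
head_mask_2d[3, 1] = 0.0                    # mask head 1 of layer 3 only
# either tensor can be passed as `head_mask`; the branch above expands it to
# [num_hidden_layers x qlen x klen x bsz x n_head] before it is multiplied into attn_prob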
...
...
@@ -1033,6 +1092,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                    word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]

        hids = []
+       attentions = []
        if self.attn_type == 0: # default
            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
...
...
@@ -1046,7 +1106,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
-               core_out = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i)
+               layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
+                                     mems=mems_i, head_mask=head_mask[i])
+               core_out = layer_outputs[0]
+               if self.output_attentions:
+                   attentions.append(layer_outputs[1])
        elif self.attn_type == 1: # learnable
            core_out = self.drop(word_emb)
            for i, layer in enumerate(self.layers):
...
...
@@ -1058,8 +1122,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                r_emb, r_bias = self.r_emb[i], self.r_bias[i]
                mems_i = None if mems is None else mems[i]
-               core_out = layer(core_out, r_emb, self.r_w_bias[i],
-                                r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
+               layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
+                                     r_bias, dec_attn_mask=dec_attn_mask,
+                                     mems=mems_i, head_mask=head_mask[i])
+               core_out = layer_outputs[0]
+               if self.output_attentions:
+                   attentions.append(layer_outputs[1])
        elif self.attn_type == 2: # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
...
...
@@ -1074,8 +1142,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
-               core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i)
+               layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
+                                     mems=mems_i, head_mask=head_mask[i])
+               core_out = layer_outputs[0]
+               if self.output_attentions:
+                   attentions.append(layer_outputs[1])
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)
...
...
@@ -1093,16 +1164,30 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

-               core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i)
+               layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
+                                     mems=mems_i, head_mask=head_mask[i])
+               core_out = layer_outputs[0]
+               if self.output_attentions:
+                   attentions.append(layer_outputs[1])

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

-       return core_out, new_mems
-
-   def forward(self, input_ids, mems=None):
+       # We transpose back here to shape [bsz, len, hidden_dim]
+       outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
+       if self.output_hidden_states:
+           # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
+           hids.append(core_out)
+           hids = list(t.transpose(0, 1).contiguous() for t in hids)
+           outputs.append(hids)
+       if self.output_attentions:
+           # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
+           attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
+           outputs.append(attentions)
+
+       return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
+
+   def forward(self, input_ids, mems=None, head_mask=None):
""" Params:
input_ids :: [bsz, len]
mems :: optional mems from previous forward passes (or init_mems)
...
...
@@ -1122,11 +1207,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
        if mems is None:
            mems = self.init_mems(input_ids)
-       last_hidden, new_mems = self._forward(input_ids, mems=mems)
+       outputs = self._forward(input_ids, mems=mems, head_mask=head_mask)

-       # We transpose back here to shape [bsz, len, hidden_dim]
-       last_hidden = last_hidden.transpose(0, 1).contiguous()
-       return (last_hidden, new_mems)
+       return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
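A hedged usage sketch of the new list-style outputs and of threading the memory between two segments; the `from_pretrained` shortcut, the pretrained key and the import path are assumptions for illustration and are not part of this hunk.

import torch
from pytorch_pretrained_bert import TransfoXLModel  # import path assumed

model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()
segment_1 = torch.randint(0, 1000, (1, 16))  # [bsz, len]
segment_2 = torch.randint(0, 1000, (1, 16))
with torch.no_grad():
    outputs_1 = model(segment_1)                  # [last_hidden, new_mems, ...]
    hidden_1, mems_1 = outputs_1[0], outputs_1[1]
    outputs_2 = model(segment_2, mems=mems_1)     # second segment attends to the cached states
    hidden_2, mems_2 = outputs_2[0], outputs_2[1]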
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
...
...
@@ -1218,7 +1301,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    def init_mems(self, data):
        return self.transformer.init_mems(data)

-   def forward(self, input_ids, labels=None, mems=None):
+   def forward(self, input_ids, labels=None, mems=None, head_mask=None):
        """ Params:
                input_ids :: [bsz, len]
                labels :: [bsz, len]
...
...
@@ -1235,19 +1318,26 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
        bsz = input_ids.size(0)
        tgt_len = input_ids.size(1)

-       last_hidden, new_mems = self.transformer(input_ids, mems)
+       transformer_outputs = self.transformer(input_ids, mems, head_mask)

+       last_hidden = transformer_outputs[0]
        pred_hid = last_hidden[:, -tgt_len:]
+       outputs = transformer_outputs[1:]
        if self.sample_softmax > 0 and self.training:
            assert self.config.tie_weight
            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
+           outputs = [softmax_output] + outputs
+           if labels is not None:
+               # TODO: This is not implemented
+               raise NotImplementedError
        else:
            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
            if labels is None:
                softmax_output = softmax_output.view(bsz, tgt_len, -1)
+               outputs = [softmax_output] + outputs
            else:
                softmax_output = softmax_output.view(bsz, tgt_len)
+               outputs = [softmax_output, None] + outputs

-       # We transpose back
-       return (softmax_output, new_mems)
+       return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
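The corresponding sketch for the LM head, showing how the first output element changes meaning with and without `labels`; same assumptions as the TransfoXLModel sketch above.

import torch
from pytorch_pretrained_bert import TransfoXLLMHeadModel  # import path assumed

lm_model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
lm_model.eval()
input_ids = torch.randint(0, 1000, (1, 16))
with torch.no_grad():
    nll = lm_model(input_ids, labels=input_ids)[0]  # [bsz, tgt_len] token losses (the logits slot is None)
    log_probs = lm_model(input_ids)[0]              # [bsz, tgt_len, n_token] from the adaptive softmax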
pytorch_pretrained_bert/modeling_xlm.py
View file @
1484d67d
...
...
@@ -73,6 +73,7 @@ class XLMConfig(PretrainedConfig):
    def __init__(self,
                 vocab_size_or_config_json_file,
+                causal=True,
                 d_model=1024,
                 n_layer=24,
                 n_head=16,
...
...
@@ -145,6 +146,7 @@ class XLMConfig(PretrainedConfig):
            self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.n_token = vocab_size_or_config_json_file
+           self.causal = causal
            self.d_model = d_model
            self.n_layer = n_layer
            self.n_head = n_head
...
...
@@ -396,7 +398,6 @@ class XLMPreTrainedModel(PreTrainedModel):
"""
config_class
=
XLMConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map
=
PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights
=
None
base_model_prefix
=
"xlm"
...
...
@@ -429,7 +430,7 @@ class XLMModel(XLMPreTrainedModel):
                   'hidden_dim', 'dropout', 'attention_dropout', 'asm', 'asm_cutoffs', 'asm_div_value']

-   def __init__(self, params, output_attentions=False, keep_multihead_output=False):  #, dico, is_encoder, with_output):
+   def __init__(self, params, output_attentions=False, output_hidden_states=False):  #, dico, is_encoder, with_output):
        """ XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
            Paper: https://arxiv.org/abs/1901.07291
            Original code: https://github.com/facebookresearch/XLM
...
...
@@ -483,11 +484,13 @@ class XLMModel(XLMPreTrainedModel):
"""
super
(
XLMModel
,
self
).
__init__
(
params
)
self
.
output_attentions
=
output_attentions
self
.
output_hidden_states
=
output_hidden_states
# encoder / decoder, output layer
# self.is_encoder = is_encoder
# self.is_decoder = not is_encoder
# self.with_output = with_output
self
.
causal
=
params
.
causal
# dictionary / languages
self
.
n_langs
=
params
.
n_langs
...
...
@@ -536,63 +539,45 @@ class XLMModel(XLMPreTrainedModel):
            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim,
                                            dropout=self.dropout,
                                            gelu_activation=params.gelu_activation))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))

        # output layer
        # if self.with_output:
        #     self.pred_layer = PredLayer(params)
        #     if params.share_inout_emb:
        #         self.pred_layer.proj.weight = self.embeddings.weight

    # def forward(self, mode, **kwargs):
    #     """
    #     Forward function with different forward modes.
    #     ### Small hack to handle PyTorch distributed.
    #     """
    #     if mode == 'fwd':
    #         return self.fwd(**kwargs)
    #     elif mode == 'predict':
    #         return self.predict(**kwargs)
    #     else:
    #         raise Exception("Unknown mode: %s" % mode)

-   def forward(self, x, lengths, causal, src_enc=None, src_len=None, positions=None, langs=None, cache=None):
+   def forward(self, x, lengths, positions=None, langs=None, cache=None, head_mask=None):  # src_enc=None, src_len=None,
        """
        Inputs:
-           `x` LongTensor(slen, bs), containing word indices
+           `x` LongTensor(bs, slen), containing word indices
            `lengths` LongTensor(bs), containing the length of each sentence
-           `causal` Boolean, if True, the attention is only done over previous hidden states
-           `positions` LongTensor(slen, bs), containing word positions
-           `langs` LongTensor(slen, bs), containing language IDs
+           `positions` LongTensor(bs, slen), containing word positions
+           `langs` LongTensor(bs, slen), containing language IDs
        """
        # lengths = (x != self.pad_index).float().sum(dim=1)
        # mask = x != self.pad_index

        # check inputs
-       slen, bs = x.size()
+       bs, slen = x.size()
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen
-       x = x.transpose(0, 1)  # batch size as dimension 0
-       assert (src_enc is None) == (src_len is None)
-       if src_enc is not None:
-           assert self.is_decoder
-           assert src_enc.size(0) == bs
+       # x = x.transpose(0, 1)  # batch size as dimension 0
+       # assert (src_enc is None) == (src_len is None)
+       # if src_enc is not None:
+       #     assert self.is_decoder
+       #     assert src_enc.size(0) == bs

        # generate masks
-       mask, attn_mask = get_masks(slen, lengths, causal)
-       if self.is_decoder and src_enc is not None:
-           src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
+       mask, attn_mask = get_masks(slen, lengths, self.causal)
+       # if self.is_decoder and src_enc is not None:
+       #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

        # positions
        if positions is None:
            positions = x.new(slen).long()
            positions = torch.arange(slen, out=positions).unsqueeze(0)
        else:
-           assert positions.size() == (slen, bs)
-           positions = positions.transpose(0, 1)
+           assert positions.size() == (bs, slen)  # (slen, bs)
+           # positions = positions.transpose(0, 1)

        # langs
        if langs is not None:
-           assert langs.size() == (slen, bs)
-           langs = langs.transpose(0, 1)
+           assert langs.size() == (bs, slen)  # (slen, bs)
+           # langs = langs.transpose(0, 1)

        # do not recompute cached elements
        if cache is not None:
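A small illustration of the `lengths` input expected by the new batch-first signature; the pad index and the token values are arbitrary, mirroring the commented-out hint in the hunk above.

import torch

pad_index = 2                                  # arbitrary pad id for the illustration
x = torch.tensor([[5, 7, 9, 4, pad_index],
                  [6, 8, pad_index, pad_index, pad_index]])   # [bs, slen]
lengths = (x != pad_index).long().sum(dim=1)   # tensor([4, 2]), one length per sentence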
...
...
@@ -614,620 +599,50 @@ class XLMModel(XLMPreTrainedModel):
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
+       hidden_states = []
+       attentions = []
        for i in range(self.n_layers):
+           if self.output_hidden_states:
+               hidden_states.append(tensor)

            # self attention
-           attn = self.attentions[i](tensor, attn_mask, cache=cache)
+           attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
+           attn = attn_outputs[0]
+           if self.output_attentions:
+               attentions.append(attn_outputs[1])
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

            # encoder attention (for decoder only)
-           if self.is_decoder and src_enc is not None:
-               attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
-               attn = F.dropout(attn, p=self.dropout, training=self.training)
-               tensor = tensor + attn
-               tensor = self.layer_norm15[i](tensor)
+           # if self.is_decoder and src_enc is not None:
+           #     attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
+           #     attn = F.dropout(attn, p=self.dropout, training=self.training)
+           #     tensor = tensor + attn
+           #     tensor = self.layer_norm15[i](tensor)

            # FFN
            tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norm2[i](tensor)
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

+       # Add last hidden state
+       if self.output_hidden_states:
+           hidden_states.append(tensor)

        # update cache length
        if cache is not None:
            cache['slen'] += tensor.size(1)

-       # move back sequence length to dimension 0
-       tensor = tensor.transpose(0, 1)
-       return tensor
-   def predict(self, tensor, pred_mask, y, get_scores):
-       """
-       Given the last hidden state, compute word scores and/or the loss.
-           `pred_mask` is a ByteTensor of shape (slen, bs), filled with 1 when
-               we need to predict a word
-           `y` is a LongTensor of shape (pred_mask.sum(),)
-           `get_scores` is a boolean specifying whether we need to return scores
-       """
-       masked_tensor = tensor[pred_mask.unsqueeze(-1).expand_as(tensor)].view(-1, self.dim)
-       scores, loss = self.pred_layer(masked_tensor, y, get_scores)
-       return scores, loss
-   def generate(self, src_enc, src_len, tgt_lang_id, max_len=200, sample_temperature=None):
-       """
-       Decode a sentence given initial start.
-       `x`:
-           - LongTensor(bs, slen)
-               <EOS> W1 W2 W3 <EOS> <PAD>
-               <EOS> W1 W2 W3   W4  <EOS>
-       `lengths`:
-           - LongTensor(bs) [5, 6]
-       `positions`:
-           - False, for regular "arange" positions (LM)
-           - True, to reset positions from the new generation (MT)
-       `langs`:
-           - must be None if the model only supports one language
-           - lang_id if only one language is involved (LM)
-           - (lang_id1, lang_id2) if two languages are involved (MT)
-       """
-       # input batch
-       bs = len(src_len)
-       assert src_enc.size(0) == bs
-
-       # generated sentences
-       generated = src_len.new(max_len, bs)  # upcoming output
-       generated.fill_(self.pad_index)       # fill upcoming ouput with <PAD>
-       generated[0].fill_(self.eos_index)    # we use <EOS> for <BOS> everywhere
-
-       # positions
-       positions = src_len.new(max_len).long()
-       positions = torch.arange(max_len, out=positions).unsqueeze(1).expand(max_len, bs)
-
-       # language IDs
-       langs = src_len.new(max_len).long().fill_(tgt_lang_id)
-       langs = langs.unsqueeze(1).expand(max_len, bs)
-
-       # current position / max lengths / length of generated sentences / unfinished sentences
-       cur_len = 1
-       gen_len = src_len.clone().fill_(1)
-       unfinished_sents = src_len.clone().fill_(1)
-
-       # cache compute states
-       cache = {'slen': 0}
-
-       while cur_len < max_len:
-
-           # compute word scores
-           tensor = self.forward(
-               'fwd',
-               x=generated[:cur_len],
-               lengths=gen_len,
-               positions=positions[:cur_len],
-               langs=langs[:cur_len],
-               causal=True,
-               src_enc=src_enc,
-               src_len=src_len,
-               cache=cache
-           )
-           assert tensor.size() == (1, bs, self.dim)
-           tensor = tensor.data[-1, :, :]               # (bs, dim)
-           scores = self.pred_layer.get_scores(tensor)  # (bs, n_words)
-
-           # select next words: sample or greedy
-           if sample_temperature is None:
-               next_words = torch.topk(scores, 1)[1].squeeze(1)
-           else:
-               next_words = torch.multinomial(F.softmax(scores / sample_temperature, dim=1), 1).squeeze(1)
-           assert next_words.size() == (bs,)
-
-           # update generations / lengths / finished sentences / current length
-           generated[cur_len] = next_words * unfinished_sents + self.pad_index * (1 - unfinished_sents)
-           gen_len.add_(unfinished_sents)
-           unfinished_sents.mul_(next_words.ne(self.eos_index).long())
-           cur_len = cur_len + 1
-
-           # stop when there is a </s> in each sentence, or if we exceed the maximul length
-           if unfinished_sents.max() == 0:
-               break
-
-       # add <EOS> to unfinished sentences
-       if cur_len == max_len:
-           generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)
-
-       # sanity check
-       assert (generated == self.eos_index).sum() == 2 * bs
-
-       return generated[:cur_len], gen_len
-   def generate_beam(self, src_enc, src_len, tgt_lang_id, beam_size, length_penalty, early_stopping, max_len=200):
-       """
-       Decode a sentence given initial start.
-       `x`:
-           - LongTensor(bs, slen)
-               <EOS> W1 W2 W3 <EOS> <PAD>
-               <EOS> W1 W2 W3   W4  <EOS>
-       `lengths`:
-           - LongTensor(bs) [5, 6]
-       `positions`:
-           - False, for regular "arange" positions (LM)
-           - True, to reset positions from the new generation (MT)
-       `langs`:
-           - must be None if the model only supports one language
-           - lang_id if only one language is involved (LM)
-           - (lang_id1, lang_id2) if two languages are involved (MT)
-       """
-       # check inputs
-       assert src_enc.size(0) == src_len.size(0)
-       assert beam_size >= 1
-
-       # batch size / number of words
-       bs = len(src_len)
-       n_words = self.n_words
-
-       # tensor = tensor.transpose(0, 1)
-       # expand to beam size the source latent representations / source lengths
-       src_enc = src_enc.unsqueeze(1).expand((bs, beam_size) + src_enc.shape[1:]).contiguous().view((bs * beam_size,) + src_enc.shape[1:])
-       src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)
-
-       # generated sentences (batch with beam current hypotheses)
-       generated = src_len.new(max_len, bs * beam_size)  # upcoming output
-       generated.fill_(self.pad_index)                   # fill upcoming ouput with <PAD>
-       generated[0].fill_(self.eos_index)                # we use <EOS> for <BOS> everywhere
-
-       # generated hypotheses
-       generated_hyps = [BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) for _ in range(bs)]
-
-       # positions
-       positions = src_len.new(max_len).long()
-       positions = torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated)
-
-       # language IDs
-       langs = positions.clone().fill_(tgt_lang_id)
-
-       # scores for each sentence in the beam
-       beam_scores = src_enc.new(bs, beam_size).fill_(0)
-       beam_scores[:, 1:] = -1e9
-       beam_scores = beam_scores.view(-1)
-
-       # current position
-       cur_len = 1
-
-       # cache compute states
-       cache = {'slen': 0}
-
-       # done sentences
-       done = [False for _ in range(bs)]
-
-       while cur_len < max_len:
-
-           # compute word scores
-           tensor = self.forward(
-               'fwd',
-               x=generated[:cur_len],
-               lengths=src_len.new(bs * beam_size).fill_(cur_len),
-               positions=positions[:cur_len],
-               langs=langs[:cur_len],
-               causal=True,
-               src_enc=src_enc,
-               src_len=src_len,
-               cache=cache
-           )
-           assert tensor.size() == (1, bs * beam_size, self.dim)
-           tensor = tensor.data[-1, :, :]               # (bs * beam_size, dim)
-           scores = self.pred_layer.get_scores(tensor)  # (bs * beam_size, n_words)
-           scores = F.log_softmax(scores, dim=-1)       # (bs * beam_size, n_words)
-           assert scores.size() == (bs * beam_size, n_words)
-
-           # select next words with scores
-           _scores = scores + beam_scores[:, None].expand_as(scores)  # (bs * beam_size, n_words)
-           _scores = _scores.view(bs, beam_size * n_words)            # (bs, beam_size * n_words)
-
-           next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
-           assert next_scores.size() == next_words.size() == (bs, 2 * beam_size)
-
-           # next batch beam content
-           # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
-           next_batch_beam = []
-
-           # for each sentence
-           for sent_id in range(bs):
-
-               # if we are done with this sentence
-               done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
-               if done[sent_id]:
-                   next_batch_beam.extend([(0, self.pad_index, 0)] * beam_size)  # pad the batch
-                   continue
-
-               # next sentence beam content
-               next_sent_beam = []
-
-               # next words for this sentence
-               for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
-
-                   # get beam and word IDs
-                   beam_id = idx // n_words
-                   word_id = idx % n_words
-
-                   # end of sentence, or next word
-                   if word_id == self.eos_index or cur_len + 1 == max_len:
-                       generated_hyps[sent_id].add(generated[:cur_len, sent_id * beam_size + beam_id].clone(), value.item())
-                   else:
-                       next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
-
-                   # the beam for next step is full
-                   if len(next_sent_beam) == beam_size:
-                       break
-
-               # update next beam content
-               assert len(next_sent_beam) == 0 if cur_len + 1 == max_len else beam_size
-               if len(next_sent_beam) == 0:
-                   next_sent_beam = [(0, self.pad_index, 0)] * beam_size  # pad the batch
-               next_batch_beam.extend(next_sent_beam)
-               assert len(next_batch_beam) == beam_size * (sent_id + 1)
-
-           # sanity check / prepare next batch
-           assert len(next_batch_beam) == bs * beam_size
-           beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
-           beam_words = generated.new([x[1] for x in next_batch_beam])
-           beam_idx = src_len.new([x[2] for x in next_batch_beam])
-
-           # re-order batch and internal states
-           generated = generated[:, beam_idx]
-           generated[cur_len] = beam_words
-           for k in cache.keys():
-               if k != 'slen':
-                   cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
-
-           # update current length
-           cur_len = cur_len + 1
-
-           # stop when we are done with each sentence
-           if all(done):
-               break
-
-       # visualize hypotheses
-       # print([len(x) for x in generated_hyps], cur_len)
-       # globals().update( locals() );
-       # !import code; code.interact(local=vars())
-       # for ii in range(bs):
-       #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
-       #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
-       #     print("")
-
-       # select the best hypotheses
-       tgt_len = src_len.new(bs)
-       best = []
-
-       for i, hypotheses in enumerate(generated_hyps):
-           best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
-           tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
-           best.append(best_hyp)
-
-       # generate target batch
-       decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index)
-       for i, hypo in enumerate(best):
-           decoded[:tgt_len[i] - 1, i] = hypo
-           decoded[tgt_len[i] - 1, i] = self.eos_index
-
-       # sanity check
-       assert (decoded == self.eos_index).sum() == 2 * bs
-
-       return decoded, tgt_len
-class XLMModel(XLMPreTrainedModel):
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
-       super(XLMModel, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.mem_len = config.mem_len
-       self.reuse_len = config.reuse_len
-       self.d_model = config.d_model
-       self.same_length = config.same_length
-       self.attn_type = config.attn_type
-       self.bi_data = config.bi_data
-       self.clamp_len = config.clamp_len
-
-       self.word_embedding = nn.Embedding(config.n_token, config.d_model)
-       self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
-       layer = XLMLayer(config, output_attentions=output_attentions,
-                        keep_multihead_output=keep_multihead_output)
-       self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
-       self.dropout = nn.Dropout(config.dropout)
-
-   def prune_heads(self, heads_to_prune):
-       """ Prunes heads of the model.
-           heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
-       """
-       for layer, heads in heads_to_prune.items():
-           self.layer[layer].attention.prune_heads(heads)
-
-   def get_multihead_outputs(self):
-       """ Gather all multi-head outputs.
-           Return: list (layers) of multihead module outputs with gradients
-       """
-       return [layer.attention.self.multihead_output for layer in self.layer]
-
-   def create_mask(self, qlen, mlen):
-       """ create causal attention mask.
-            float mask where 1.0 indicate masked, 0.0 indicated not-masked.
-            same_length=False:      same_length=True:
-            <mlen > <  qlen >       <mlen > <  qlen >
-         ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
-           [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
-      qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
-           [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
-         v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]
-       """
-       attn_mask = torch.ones([qlen, qlen])
-       mask_up = torch.triu(attn_mask, diagonal=1)
-       attn_mask_pad = torch.zeros([qlen, mlen])
-       ret = torch.cat([attn_mask_pad, mask_up], dim=1)
-       if self.same_length:
-           mask_lo = torch.tril(attn_mask, diagonal=-1)
-           ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
-
-       ret = ret.to(next(self.parameters()))
-       return ret
-
-   def cache_mem(self, curr_out, prev_mem):
-       """cache hidden states into memory."""
-       if self.mem_len is None or self.mem_len == 0:
-           return None
-       else:
-           if self.reuse_len is not None and self.reuse_len > 0:
-               curr_out = curr_out[:self.reuse_len]
-
-           if prev_mem is None:
-               new_mem = curr_out[-self.mem_len:]
-           else:
-               new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]
-
-       return new_mem.detach()
-
-   @staticmethod
-   def positional_embedding(pos_seq, inv_freq, bsz=None):
-       sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
-       pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
-       pos_emb = pos_emb[:, None, :]
-       if bsz is not None:
-           pos_emb = pos_emb.expand(-1, bsz, -1)
-       return pos_emb
-
-   def relative_positional_encoding(self, qlen, klen, bsz=None):
-       """create relative positional encoding."""
-       freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
-       inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
-
-       if self.attn_type == 'bi':
-           # beg, end = klen - 1, -qlen
-           beg, end = klen, -qlen
-       elif self.attn_type == 'uni':
-           # beg, end = klen - 1, -1
-           beg, end = klen, -1
-       else:
-           raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))
-
-       if self.bi_data:
-           fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
-           bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)
-
-           if self.clamp_len > 0:
-               fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
-               bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
-
-           if bsz is not None:
-               fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
-               bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
-           else:
-               fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
-               bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
-
-           pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
-       else:
-           fwd_pos_seq = torch.arange(beg, end, -1.0)
-           if self.clamp_len > 0:
-               fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
-           pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
-
-       pos_emb = pos_emb.to(next(self.parameters()))
-       return pos_emb
-
-   def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
-               mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
-       """
-       Args:
-           inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
-           token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
-           input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
-               0 for real tokens and 1 for padding.
-           attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
-               but with 1 for real tokens and 0 for padding.
-               Added for easy compatibility with the XLM model (which uses this negative masking).
-               You can only uses one among `input_mask` and `attention_mask`
-           mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
-               from previous batches. The length of the list equals n_layer.
-               If None, no memory is used.
-           perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
-               If perm_mask[k, i, j] = 0, i attend to j in batch k;
-               if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
-               If None, each position attends to all the others.
-           target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
-               If target_mapping[k, i, j] = 1, the i-th predict in batch k is
-               on the j-th token.
-               Only used during pretraining for partial prediction.
-               Set to None during finetuning.
-           inp_q: [optional] float32 Tensor in shape [bsz, len].
-               1 for tokens with losses and 0 for tokens without losses.
-               Only used during pretraining for two-stream attention.
-               Set to None during finetuning.
-           mem_len: int, the number of tokens to cache.
-           reuse_len: int, the number of tokens in the currect batch to be cached
-               and reused in the future.
-           bi_data: bool, whether to use bidirectional input pipeline.
-               Usually set to True during pretraining and False during finetuning.
-           clamp_len: int, clamp all relative distances larger than clamp_len.
-               -1 means no clamping.
-           same_length: bool, whether to use the same attention length for each token.
-           summary_type: str, "last", "first", "mean", or "attn". The method
-               to pool the input to get a vector representation.
-       """
-       # the original code for XLM uses shapes [len, bsz] with the batch dimension at the end
-       # but we want a unified interface in the library with the batch size on the first dimension
-       # so we move here the first dimension (batch) to the end
-       inp_k = inp_k.transpose(0, 1).contiguous()
-       token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
-       input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
-       attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
-       perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
-       target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
-       inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None
-
-       qlen, bsz = inp_k.shape[0], inp_k.shape[1]
-       mlen = mems[0].shape[0] if mems is not None else 0
-       klen = mlen + qlen
-
-       dtype_float = next(self.parameters()).dtype
-       device = next(self.parameters()).device
-
-       ##### Attention mask
-       # causal attention mask
-       if self.attn_type == 'uni':
-           attn_mask = self.create_mask(qlen, mlen)
-           attn_mask = attn_mask[:, :, None, None]
-       elif self.attn_type == 'bi':
-           attn_mask = None
-       else:
-           raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
-
-       # data mask: input mask & perm mask
-       assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " \
-           "or attention_mask (uses 0 for padding, added for compatbility with XLM). Please choose one."
-       if input_mask is None and attention_mask is not None:
-           input_mask = 1.0 - attention_mask
-       if input_mask is not None and perm_mask is not None:
-           data_mask = input_mask[None] + perm_mask
-       elif input_mask is not None and perm_mask is None:
-           data_mask = input_mask[None]
-       elif input_mask is None and perm_mask is not None:
-           data_mask = perm_mask
-       else:
-           data_mask = None
-
-       if data_mask is not None:
-           # all mems can be attended to
-           mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
-           data_mask = torch.cat([mems_mask, data_mask], dim=1)
-           if attn_mask is None:
-               attn_mask = data_mask[:, :, :, None]
-           else:
-               attn_mask += data_mask[:, :, :, None]
-
-       if attn_mask is not None:
-           attn_mask = (attn_mask > 0).to(dtype_float)
-
-       if attn_mask is not None:
-           non_tgt_mask = -torch.eye(qlen).to(attn_mask)
-           non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
-           non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
-       else:
-           non_tgt_mask = None
-
-       ##### Word embeddings and prepare h & g hidden states
-       word_emb_k = self.word_embedding(inp_k)
-       output_h = self.dropout(word_emb_k)
-       if inp_q is not None:
-           if target_mapping is not None:
-               word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
-           else:
-               inp_q_ext = inp_q[:, :, None]
-               word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
-           output_g = self.dropout(word_emb_q)
-       else:
-           output_g = None
-
-       ##### Segment embedding
-       if token_type_ids is not None:
-           # Convert `token_type_ids` to one-hot `seg_mat`
-           mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
-           cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
-           # `1` indicates not in the same segment [qlen x klen x bsz]
-           seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
-           seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
-       else:
-           seg_mat = None
-
-       ##### Positional encoding
-       pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
-       pos_emb = self.dropout(pos_emb)
-
-       ##### Head mask if needed (for bertology/pruning)
-       # 1.0 in head_mask indicate we keep the head
-       # attention_probs has shape bsz x n_heads x N x N
-       # input head_mask has shape [num_heads] or [n_layer x num_heads]
-       # and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length]
-       if head_mask is not None:
-           if head_mask.dim() == 1:
-               head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-               head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
-           elif head_mask.dim() == 2:
-               head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
-           head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
-       else:
-           head_mask = [None] * self.config.n_layer
-
-       new_mems = []
-       if mems is None:
-           mems = [None] * len(self.layer)
-
-       hidden_states = []
-       attentions = []
-       for i, layer_module in enumerate(self.layer):
-           # cache new mems
-           new_mems.append(self.cache_mem(output_h, mems[i]))
-
-           # Save hidden_states
-           if output_g is None:
-               hidden_states.append(output_h)
-           else:
-               hidden_states.append((output_h, output_g))
-
-           output_h, output_g = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
-                                             r=pos_emb, seg_mat=seg_mat, mems=mems[i],
-                                             target_mapping=target_mapping, head_mask=head_mask)
-
-       # Save last hidden_state
-       if output_g is None:
-           hidden_states.append(output_h)
-       else:
-           hidden_states.append((output_h, output_g))
-
-       # Select the right output and add dropout
-       output = self.dropout(output_g if output_g is not None else output_h)
-
-       # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
-       output = output.permute(1, 0, 2).contiguous()
-       if output_g is None:
-           hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
-       else:
-           hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
-
-       # Build the list of outputs
-       outputs = [output, new_mems]
-       if self.output_attentions:
-           outputs.append(attentions)
-       return outputs
+       outputs = [tensor]
+       if self.output_hidden_states:
+           outputs.append(hidden_states)
+       if self.output_attentions:
+           outputs.append(attentions)
+       return outputs  # outputs, (hidden_states), (attentions)

class XLMPredLayer(nn.Module):
...
...
@@ -1275,8 +690,11 @@ class XLMPredLayer(nn.Module):
        return self.proj.log_prob(x) if self.asm else self.proj(x)

-class XLMLMHeadModel(XLMPreTrainedModel):
-   """XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
+class XLMWithLMHeadModel(XLMPreTrainedModel):
+   """ XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
+       Paper: https://arxiv.org/abs/1901.07291
+       Original code: https://github.com/facebookresearch/XLM

    Params:
        `config`: a XLMConfig class instance with the configuration to build a new model
...
...
@@ -1285,36 +703,29 @@ class XLMLMHeadModel(XLMPreTrainedModel):
            This can be used to compute head importance metrics. Default: False
    Inputs:
-       inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
-       token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
-       attention_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
-           0 for real tokens and 1 for padding.
-       mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
-           from previous batches. The length of the list equals n_layer.
-           If None, no memory is used.
-       perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
-           If perm_mask[k, i, j] = 0, i attend to j in batch k;
-           if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
-           If None, each position attends to all the others.
-       target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
-           If target_mapping[k, i, j] = 1, the i-th predict in batch k is
-           on the j-th token.
-           Only used during pretraining for partial prediction.
-           Set to None during finetuning.
-       inp_q: [optional] float32 Tensor in shape [bsz, len].
-           1 for tokens with losses and 0 for tokens without losses.
-           Only used during pretraining for two-stream attention.
-           Set to None during finetuning.
+       `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
+           with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
+           `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
+       `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
+           types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
+           a `sentence B` token (see XLM paper for more details).
+       `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
+           selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
+           input sequence length in the current batch. It's the mask that we typically use for attention when
+           a batch has varying length sentences.
+       `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
+       `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
+           It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each
-               encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, d_model],
+               encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
-               to the last attention block of shape [batch_size, sequence_length, d_model],
-       `pooled_output`: a torch.FloatTensor of size [batch_size, d_model] which is the output of a
+               to the last attention block of shape [batch_size, sequence_length, hidden_size],
+       `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLS`) to train on the Next-Sentence task (see XLM's paper).
...
...
@@ -1325,8 +736,8 @@ class XLMLMHeadModel(XLMPreTrainedModel):
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

-   config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, d_model=768,
-       n_layer=12, num_attention_heads=12, intermediate_size=3072)
+   config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+       num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.XLMModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
...
...
@@ -1341,9 +752,7 @@ class XLMLMHeadModel(XLMPreTrainedModel):
-       self.same_length = config.same_length
        self.transformer = XLMModel(config, output_attentions=output_attentions,
                                    output_hidden_states=output_hidden_states)
-       self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
-       # Tie weights
+       self.pred_layer = XLMPredLayer(config)
        self.apply(self.init_weights)
        self.tie_weights()
...
...
@@ -1351,10 +760,9 @@ class XLMLMHeadModel(XLMPreTrainedModel):
    def tie_weights(self):
        """ Make sure we are sharing the embeddings
        """
-       self.lm_loss.weight = self.transformer.word_embedding.weight
+       self.pred_layer.proj.weight = self.transformer.embeddings.weight

-   def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
-               mems=None, perm_mask=None, target_mapping=None, inp_q=None,
+   def forward(self, x, lengths, positions=None, langs=None, cache=None, labels=None,
                head_mask=None):
        """
        Args:
...
...
@@ -1382,11 +790,10 @@ class XLMLMHeadModel(XLMPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
transformer_outputs
=
self
.
transformer
(
inp_k
,
token_type_ids
,
input_mask
,
attention_mask
,
mems
,
perm_mask
,
target_mapping
,
inp_q
,
head_mask
)
transformer_outputs
=
self
.
transformer
(
x
,
lengths
,
positions
=
positions
,
langs
=
langs
,
cache
=
cache
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
logits
=
self
.
lm_loss
(
output
)
logits
=
self
.
pred_layer
(
output
,
labels
)
outputs
=
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
...
...
pytorch_pretrained_bert/modeling_xlnet.py
View file @
1484d67d
...
...
@@ -198,7 +198,7 @@ class XLNetConfig(PretrainedConfig):
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
-                vocab_size_or_config_json_file,
+                vocab_size_or_config_json_file=32000,
                 d_model=1024,
                 n_layer=24,
                 n_head=16,
...
...
@@ -221,7 +221,12 @@ class XLNetConfig(PretrainedConfig):
                 bi_data=False,
                 clamp_len=-1,
                 same_length=False,
-                finetuning_task=None):
+                finetuning_task=None,
+                num_labels=2,
+                summary_type="last",
+                use_proj=True,
+                **kwargs):
        """Constructs XLNetConfig.
        Args:
...
...
@@ -265,6 +270,8 @@ class XLNetConfig(PretrainedConfig):
            same_length: bool, whether to use the same attention length for each token.
            finetuning_task: name of the glue task on which the model was fine-tuned if any
        """
+       super(XLNetConfig, self).__init__(**kwargs)
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
...
...
@@ -297,7 +304,11 @@ class XLNetConfig(PretrainedConfig):
            self.bi_data = bi_data
            self.clamp_len = clamp_len
            self.same_length = same_length
            self.finetuning_task = finetuning_task
+           self.num_labels = num_labels
+           self.summary_type = summary_type
+           self.use_proj = use_proj
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
...
...
@@ -323,9 +334,10 @@ except ImportError:
            return self.weight * x + self.bias

class XLNetRelativeAttention(nn.Module):
-   def __init__(self, config, output_attentions=False):
+   def __init__(self, config):
        super(XLNetRelativeAttention, self).__init__()
-       self.output_attentions = output_attentions
+       self.output_attentions = config.output_attentions

        if config.d_model % config.n_head != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
...
...
@@ -533,10 +545,9 @@ class XLNetFeedForward(nn.Module):
        return output

class XLNetLayer(nn.Module):
-   def __init__(self, config, output_attentions=False):
+   def __init__(self, config):
        super(XLNetLayer, self).__init__()
-       self.output_attentions = output_attentions
-       self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions)
+       self.rel_attn = XLNetRelativeAttention(config)
        self.ff = XLNetFeedForward(config)
        self.dropout = nn.Dropout(config.dropout)
...
...
@@ -562,7 +573,6 @@ class XLNetPreTrainedModel(PreTrainedModel):
"""
config_class
=
XLNetConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map
=
PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights
=
load_tf_weights_in_xlnet
base_model_prefix
=
"transformer"
...
...
@@ -589,10 +599,10 @@ class XLNetPreTrainedModel(PreTrainedModel):
class XLNetModel(XLNetPreTrainedModel):
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(XLNetModel, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
+       self.output_attentions = config.output_attentions
+       self.output_hidden_states = config.output_hidden_states
        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
...
...
@@ -601,25 +611,17 @@ class XLNetModel(XLNetPreTrainedModel):
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len
+       self.n_layer = config.n_layer

        self.word_embedding = nn.Embedding(config.n_token, config.d_model)
        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
-       layer = XLNetLayer(config, output_attentions=output_attentions)
+       layer = XLNetLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
        self.dropout = nn.Dropout(config.dropout)

-   def prune_heads(self, heads_to_prune):
-       """ Prunes heads of the model.
-           heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
-       """
-       for layer, heads in heads_to_prune.items():
-           self.layer[layer].attention.prune_heads(heads)
-
-   def get_multihead_outputs(self):
-       """ Gather all multi-head outputs.
-           Return: list (layers) of multihead module outputs with gradients
-       """
-       return [layer.attention.self.multihead_output for layer in self.layer]
+   def _prune_heads(self, heads_to_prune):
+       logger.info("Head pruning is not implemented for XLNet")
+       pass

    def create_mask(self, qlen, mlen):
        """ create causal attention mask.
...
...
@@ -708,11 +710,11 @@ class XLNetModel(XLNetPreTrainedModel):
        pos_emb = pos_emb.to(next(self.parameters()))
        return pos_emb

-   def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
+   def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
        """
        Args:
-           inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+           input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
            token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
            input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
                0 for real tokens and 1 for padding.
...
...
@@ -751,7 +753,7 @@ class XLNetModel(XLNetPreTrainedModel):
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
inp
_k
=
inp_k
.
transpose
(
0
,
1
).
contiguous
()
inp
ut_ids
=
input_ids
.
transpose
(
0
,
1
).
contiguous
()
token_type_ids
=
token_type_ids
.
transpose
(
0
,
1
).
contiguous
()
if
token_type_ids
is
not
None
else
None
input_mask
=
input_mask
.
transpose
(
0
,
1
).
contiguous
()
if
input_mask
is
not
None
else
None
attention_mask
=
attention_mask
.
transpose
(
0
,
1
).
contiguous
()
if
attention_mask
is
not
None
else
None
...
...
@@ -759,7 +761,7 @@ class XLNetModel(XLNetPreTrainedModel):
target_mapping
=
target_mapping
.
permute
(
1
,
2
,
0
).
contiguous
()
if
target_mapping
is
not
None
else
None
inp_q
=
inp_q
.
transpose
(
0
,
1
).
contiguous
()
if
inp_q
is
not
None
else
None
qlen
,
bsz
=
inp
_k
.
shape
[
0
],
inp
_k
.
shape
[
1
]
qlen
,
bsz
=
inp
ut_ids
.
shape
[
0
],
inp
ut_ids
.
shape
[
1
]
mlen
=
mems
[
0
].
shape
[
0
]
if
mems
is
not
None
else
0
klen
=
mlen
+
qlen
...
...
@@ -810,7 +812,7 @@ class XLNetModel(XLNetPreTrainedModel):
non_tgt_mask
=
None
##### Word embeddings and prepare h & g hidden states
word_emb_k
=
self
.
word_embedding
(
inp
_k
)
word_emb_k
=
self
.
word_embedding
(
inp
ut_ids
)
output_h
=
self
.
dropout
(
word_emb_k
)
if
inp_q
is
not
None
:
if
target_mapping
is
not
None
:
...
...
@@ -838,20 +840,20 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb
=
self
.
relative_positional_encoding
(
qlen
,
klen
,
bsz
=
bsz
)
pos_emb
=
self
.
dropout
(
pos_emb
)
#
#### H
ead mask if needed
(for bertology/pruning)
#
Prepare h
ead mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [n_layer x num_heads]
# and head_mask is converted to shape [n_layer x
batch x num_heads x seq_length x seq_length
]
# input head_mask has shape [num_heads] or [n
um_hidden
_layer
s
x num_heads]
(a head_mask for each layer)
# and head_mask is converted to shape [n
um_hidden
_layer
s
x
qlen x klen x bsz x n_head
]
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
0
)
head_mask
=
head_mask
.
expand
(
self
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
1
).
unsqueeze
(
1
)
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
n_layer
head_mask
=
[
None
]
*
self
.
n_layer
new_mems
=
[]
if
mems
is
None
:
...
...
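The new comments document the target layout `[num_hidden_layers x qlen x klen x bsz x n_head]`. The standalone sketch below reproduces the broadcasting logic with plain tensors just to show the resulting shapes; the layer and head counts are arbitrary values chosen for illustration.

```python
import torch

n_layer, n_head = 12, 16
head_mask = torch.ones(n_head)          # 1.0 keeps a head, 0.0 masks it out

if head_mask.dim() == 1:
    # [n_head] -> [1, 1, 1, 1, n_head] -> [n_layer, 1, 1, 1, n_head],
    # which broadcasts against per-layer attention tensors laid out as
    # [qlen, klen, bsz, n_head]
    head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
    head_mask = head_mask.expand(n_layer, -1, -1, -1, -1)

print(head_mask.shape)  # torch.Size([12, 1, 1, 1, 16])
```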
@@ -870,7 +872,7 @@ class XLNetModel(XLNetPreTrainedModel):
                            head_mask=head_mask[i])
            output_h, output_g = outputs[:2]
            if self.output_attentions:
-                attentions.append(outputs[2:])
+                attentions.append(outputs[2])

        # Add last hidden state
        if self.output_hidden_states:
...
@@ -887,6 +889,7 @@ class XLNetModel(XLNetPreTrainedModel):
            hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
            outputs.append(hidden_states)
        if self.output_attentions:
+            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
            outputs.append(attentions)

        return outputs  # outputs, new_mems, (hidden_states), (attentions)
...
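For reference, here is a hedged sketch of how a caller would unpack the return value under the ordering noted in the trailing comment above (`outputs, new_mems, (hidden_states), (attentions)`); `model` and `input_ids` are assumed to already exist and are not defined by this diff.

```python
# Sketch only: assumes `model` is an XLNetModel and `input_ids` is a [bsz, len] LongTensor.
outputs = model(input_ids)
last_hidden, new_mems = outputs[0], outputs[1]

if model.output_hidden_states:
    hidden_states = outputs[2]   # one tensor per layer, plus the embedding output
if model.output_attentions:
    attentions = outputs[-1]     # appended last whenever attentions are enabled
```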
@@ -902,7 +905,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            This can be used to compute head importance metrics. Default: False

    Inputs:
-        inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+        input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
        token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
        input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
            0 for real tokens and 1 for padding.
...
@@ -953,16 +956,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
-    def __init__(self, config, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
        super(XLNetLMHeadModel, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
        self.attn_type = config.attn_type
        self.same_length = config.same_length

-        self.transformer = XLNetModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
+        self.transformer = XLNetModel(config)
        self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)

        # Tie weights
...
@@ -975,12 +974,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
        """
        self.lm_loss.weight = self.transformer.word_embedding.weight

-    def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
+    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
                labels=None, head_mask=None):
        """
        Args:
-            inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+            input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
            token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
            input_mask: float32 Tensor in shape [bsz, len], the input mask.
                0 for real tokens and 1 for padding.
...
@@ -1008,7 +1007,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            summary_type: str, "last", "first", "mean", or "attn". The method
                to pool the input to get a vector representation.
        """
-        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+        transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
                                               mems, perm_mask, target_mapping, inp_q, head_mask)

        logits = self.lm_loss(transformer_outputs[0])
...
@@ -1025,14 +1024,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
        return outputs  # return (loss), logits, (mems), (hidden states), (attentions)


class XLNetSequenceSummary(nn.Module):
-    def __init__(self, config, summary_type="last", use_proj=True):
+    def __init__(self, config):
        super(XLNetSequenceSummary, self).__init__()
-        self.summary_type = summary_type
-        if use_proj:
+        self.summary_type = config.summary_type
+        if config.use_proj:
            self.summary = nn.Linear(config.d_model, config.d_model)
        else:
            self.summary = None

-        if summary_type == 'attn':
+        if config.summary_type == 'attn':
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
...
@@ -1069,7 +1068,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
            to pool the input to get a vector representation. Default: last

    Inputs:
-        inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+        input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
        token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
        input_mask: float32 Tensor in shape [bsz, len], the input mask.
            0 for real tokens and 1 for padding.
...
@@ -1121,30 +1120,21 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
-    def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
-                 output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
        super(XLNetForSequenceClassification, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
        self.attn_type = config.attn_type
        self.same_length = config.same_length
-        self.summary_type = summary_type
-        self.num_labels = num_labels

-        self.transformer = XLNetModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
-        self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type, use_proj=use_proj)
-        self.logits_proj = nn.Linear(config.d_model, num_labels)
+        self.transformer = XLNetModel(config)
+        self.sequence_summary = XLNetSequenceSummary(config)
+        self.logits_proj = nn.Linear(config.d_model, config.num_labels)
        self.apply(self.init_weights)

-    def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
+    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
                labels=None, head_mask=None):
        """
        Args:
-            inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+            input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
            token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
            input_mask: float32 Tensor in shape [bsz, len], the input mask.
                0 for real tokens and 1 for padding.
...
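In the classification head above, the number of labels now comes from `config.num_labels` rather than a `num_labels` constructor argument. A sketch of the updated call site follows; the `XLNetConfig` arguments and values are placeholders, not part of this diff.

```python
# Hypothetical setup used only to illustrate that num_labels now lives on the config.
from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetForSequenceClassification

config = XLNetConfig(vocab_size_or_config_json_file=32000)  # placeholder vocab size
config.num_labels = 5                               # was __init__(..., num_labels=5) before
model = XLNetForSequenceClassification(config)      # logits_proj maps d_model -> config.num_labels
```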
@@ -1169,7 +1159,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                Only used during pretraining for two-stream attention.
                Set to None during finetuning.
        """
-        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+        transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
                                               mems, perm_mask, target_mapping, inp_q, head_mask)
        output = transformer_outputs[0]
...
@@ -1247,20 +1237,18 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
    start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
-    def __init__(self, config, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
        super(XLNetForQuestionAnswering, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states

-        self.transformer = XLNetModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
-        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+        self.transformer = XLNetModel(config)
+        self.qa_outputs = nn.Linear(config.d_model, config.num_labels)
        self.apply(self.init_weights)

-    def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
+    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
                start_positions=None, end_positions=None, head_mask=None):
-        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+        transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
                                               mems, perm_mask, target_mapping, inp_q, head_mask)
        logits = self.qa_outputs(transformer_outputs[0])
...
pytorch_pretrained_bert/tests/__init__.py 0 → 100644

tests/conftest.py → pytorch_pretrained_bert/tests/conftest.py
File moved

samples/input.txt → pytorch_pretrained_bert/tests/fixtures/input.txt
File moved

samples/sample_text.txt → pytorch_pretrained_bert/tests/fixtures/sample_text.txt
File moved

samples/test_sentencepiece.model → pytorch_pretrained_bert/tests/fixtures/test_sentencepiece.model
File moved

pytorch_pretrained_bert/tests/model_tests_commons.py 0 → 100644
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import json
import random

import torch


def create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
    for model_class in model_classes:
        config.output_hidden_states = True
        model = model_class(config=config)
        model.eval()
        head_mask = torch.zeros(tester.num_hidden_layers, tester.num_attention_heads)
        # Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
        head_mask.requires_grad_(requires_grad=True)
        outputs = model(**inputs_dict, head_mask=head_mask)

        # Compute some gradients
        output = sum(t.sum() for t in outputs[0])
        output = output.sum()
        output.backward()
        multihead_outputs = head_mask.grad

        tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
        # self.parent.assertListEqual(
        #     list(multihead_outputs[0].size()),
        #     [self.batch_size, self.num_attention_heads,
        #      self.seq_length, self.hidden_size // self.num_attention_heads])
        # self.parent.assertEqual(
        #     len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
        #     0)
        # self.parent.assertEqual(
        #     len(multihead_outputs[0][:, 0, :, :].nonzero()),
        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
        # self.parent.assertEqual(
        #     len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
        # self.parent.assertListEqual(
        #     list(multihead_outputs[1].size()),
        #     [self.batch_size, self.num_attention_heads,
        #      self.seq_length, self.hidden_size // self.num_attention_heads])
        # self.parent.assertEqual(
        #     len(multihead_outputs[1].nonzero()),
        #     multihead_outputs[1].numel())
        # self.parent.assertListEqual(
        #     list(multihead_outputs[-1].size()),
        #     [self.batch_size, self.num_attention_heads,
        #      self.seq_length, self.hidden_size // self.num_attention_heads])
        # self.parent.assertEqual(
        #     len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
        #     0)
        # self.parent.assertEqual(
        #     len(multihead_outputs[-1][:, 0, :, :].nonzero()),
        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)


def create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
    for model_class in model_classes:
        model = model_class(config=config)
        model.eval()
        heads_to_prune = {0: list(range(1, tester.num_attention_heads)),
                          -1: [0]}
        model.prune_heads(heads_to_prune)
        outputs = model(**inputs_dict)

        # output = sum(t.sum() for t in outputs[0])
        # output = output.sum()
        # output.backward()
        # multihead_outputs = bert_model.get_multihead_outputs()

        # self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
        # self.parent.assertListEqual(
        #     list(multihead_outputs[0].size()),
        #     [self.batch_size, 1,
        #      self.seq_length, self.hidden_size // self.num_attention_heads])
        # self.parent.assertListEqual(
        #     list(multihead_outputs[1].size()),
        #     [self.batch_size, self.num_attention_heads,
        #      self.seq_length, self.hidden_size // self.num_attention_heads])
        # self.parent.assertListEqual(
        #     list(multihead_outputs[-1].size()),
        #     [self.batch_size, self.num_attention_heads-1,
        #      self.seq_length, self.hidden_size // self.num_attention_heads])


def create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
    for model_class in model_classes:
        config.output_attentions = True
        config.output_hidden_states = False
        model = model_class(config)
        model.eval()
        outputs = model(**inputs_dict)
        attentions = outputs[-1]
        tester.parent.assertEqual(model.config.output_attentions, True)
        tester.parent.assertEqual(model.config.output_hidden_states, False)
        tester.parent.assertEqual(len(attentions), tester.num_hidden_layers)
        tester.parent.assertListEqual(
            list(attentions[0].shape[-3:]),
            [tester.num_attention_heads,
             tester.seq_length,
             tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])
        out_len = len(outputs)

        # Check attention is always last and order is fine
        config.output_attentions = True
        config.output_hidden_states = True
        model = model_class(config)
        model.eval()
        outputs = model(**inputs_dict)
        tester.parent.assertEqual(out_len + 1, len(outputs))
        tester.parent.assertEqual(model.config.output_attentions, True)
        tester.parent.assertEqual(model.config.output_hidden_states, True)
        attentions = outputs[-1]
        tester.parent.assertEqual(len(attentions), tester.num_hidden_layers)
        tester.parent.assertListEqual(
            list(attentions[0].shape[-3:]),
            [tester.num_attention_heads,
             tester.seq_length,
             tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])


def create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
    for model_class in model_classes:
        config.output_hidden_states = True
        config.output_attentions = False
        model = model_class(config)
        model.eval()
        outputs = model(**inputs_dict)
        hidden_states = outputs[-1]
        tester.parent.assertEqual(model.config.output_attentions, False)
        tester.parent.assertEqual(model.config.output_hidden_states, True)
        tester.parent.assertEqual(len(hidden_states), tester.num_hidden_layers + 1)
        tester.parent.assertListEqual(
            list(hidden_states[0].shape[-2:]),
            [tester.seq_length, tester.hidden_size])


def create_and_check_commons(tester, config, inputs_dict):
    create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
    create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
    create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
    create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)


def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
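A short usage note for the helper above: it draws every element uniformly from `[0, vocab_size)` and reshapes to the requested dimensions, so test inputs of any rank can be built in one call. The shapes and values below are illustrative only.

```python
# Example: fake token ids and a fake binary mask for a 13 x 7 batch.
dummy_ids = ids_tensor([13, 7], vocab_size=99)
dummy_mask = ids_tensor([13, 7], vocab_size=2)   # only 0s and 1s
assert dummy_ids.shape == (13, 7) and dummy_ids.dtype == torch.long
```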
class ConfigTester(object):
    def __init__(self, parent, config_class=None, **kwargs):
        self.parent = parent
        self.config_class = config_class
        self.inputs_dict = kwargs

    def create_and_test_config_to_json_string(self):
        config = self.config_class(**self.inputs_dict)
        obj = json.loads(config.to_json_string())
        for key, value in self.inputs_dict.items():
            self.parent.assertEqual(obj[key], value)

    def create_and_test_config_to_json_file(self):
        config_first = self.config_class(**self.inputs_dict)
        json_file_path = "/tmp/config.json"
        config_first.to_json_file(json_file_path)
        config_second = self.config_class.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def run_common_tests(self):
        self.create_and_test_config_to_json_string()
        self.create_and_test_config_to_json_file()


class GPTModelTester(object):
    def __init__(self,
                 parent,
                 batch_size=13,
                 seq_length=7,
                 is_training=True,
                 use_position_ids=True,
                 use_token_type_ids=True,
                 use_labels=True,
                 vocab_size=99,
                 n_special=1,
                 n_positions=33,
                 hidden_size=32,
                 num_hidden_layers=5,
                 num_attention_heads=4,
                 n_choices=3,
                 type_sequence_label_size=2,
                 initializer_range=0.02,
                 num_labels=3,
                 scope=None,
                 config_class=None,
                 base_model_class=None,
                 lm_head_model_class=None,
                 double_head_model_class=None,
                 ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_position_ids = use_position_ids
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.n_special = n_special
        self.n_positions = n_positions
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_choices = n_choices
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.scope = scope
        self.config_class = config_class
        self.base_model_class = base_model_class
        self.lm_head_model_class = lm_head_model_class
        self.double_head_model_class = double_head_model_class
        self.all_model_classes = (base_model_class, lm_head_model_class, double_head_model_class)

    def prepare_config_and_inputs(self):
        total_num_tokens = self.vocab_size + self.n_special
        input_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)

        position_ids = None
        if self.use_position_ids:
            position_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)

        token_type_ids = None
        if self.use_token_type_ids:
            total_voc = self.vocab_size
            token_type_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)

        mc_labels = None
        lm_labels = None
        mc_token_ids = None
        if self.use_labels:
            mc_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            lm_labels = ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
            mc_token_ids = ids_tensor([self.batch_size, self.n_choices], self.seq_length)

        config = self.config_class(
            vocab_size_or_config_json_file=self.vocab_size,
            n_special=self.n_special,
            n_positions=self.n_positions,
            n_embd=self.hidden_size,
            n_layer=self.num_hidden_layers,
            n_head=self.num_attention_heads,
            initializer_range=self.initializer_range)

        return (config, input_ids, token_type_ids, position_ids,
                mc_labels, lm_labels, mc_token_ids)

    def create_and_check_base_model(self, config, input_ids, token_type_ids, position_ids,
                                    mc_labels, lm_labels, mc_token_ids):
        model = self.base_model_class(config)
        model.eval()
        outputs = model(input_ids, position_ids, token_type_ids)
        hidden_state = outputs[0]
        self.parent.assertListEqual(
            list(hidden_state.size()),
            [self.batch_size, self.n_choices, self.seq_length, self.hidden_size])

    def create_and_check_lm_head(self, config, input_ids, token_type_ids, position_ids,
                                 mc_labels, lm_labels, mc_token_ids):
        model = self.lm_head_model_class(config)
        model.eval()
        outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
        loss, lm_logits = outputs[:2]

        total_voc = self.n_special + self.vocab_size
        self.parent.assertListEqual(
            list(lm_logits.size()),
            [self.batch_size, self.n_choices, self.seq_length, total_voc])
        self.parent.assertListEqual(
            list(loss.size()),
            [])

    def create_and_check_presents(self, config, input_ids, token_type_ids, position_ids,
                                  mc_labels, lm_labels, mc_token_ids):
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.eval()
            outputs = model(input_ids)
            presents = outputs[-1]
            self.parent.assertEqual(self.num_hidden_layers, len(presents))
            self.parent.assertListEqual(
                list(presents[0].size()),
                [2, self.batch_size * self.n_choices, self.num_attention_heads,
                 self.seq_length, self.hidden_size // self.num_attention_heads])

    def create_and_check_double_heads(self, config, input_ids, token_type_ids, position_ids,
                                      mc_labels, lm_labels, mc_token_ids):
        model = self.double_head_model_class(config)
        model.eval()
        outputs = model(input_ids, mc_token_ids,
                        lm_labels=lm_labels, mc_labels=mc_labels,
                        token_type_ids=token_type_ids, position_ids=position_ids)
        lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4]
        loss = [lm_loss, mc_loss]

        total_voc = self.n_special + self.vocab_size
        self.parent.assertListEqual(
            list(lm_logits.size()),
            [self.batch_size, self.n_choices, self.seq_length, total_voc])
        self.parent.assertListEqual(
            list(mc_logits.size()),
            [self.batch_size, self.n_choices])
        self.parent.assertListEqual(
            [list(l.size()) for l in loss],
            [[], []])

    def create_and_check_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
        for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.parent.assertIsNotNone(model)

    def create_and_check_commons(self, config, input_ids, token_type_ids, position_ids,
                                 mc_labels, lm_labels, mc_token_ids):
        inputs_dict = {'input_ids': input_ids}
        create_and_check_commons(self, config, inputs_dict)

    def run_common_tests(self, test_presents=False):
        config_and_inputs = self.prepare_config_and_inputs()
        self.create_and_check_base_model(*config_and_inputs)

        config_and_inputs = self.prepare_config_and_inputs()
        self.create_and_check_lm_head(*config_and_inputs)

        config_and_inputs = self.prepare_config_and_inputs()
        self.create_and_check_double_heads(*config_and_inputs)

        if test_presents:
            config_and_inputs = self.prepare_config_and_inputs()
            self.create_and_check_presents(*config_and_inputs)

        config_and_inputs = self.prepare_config_and_inputs()
        self.create_and_check_commons(*config_and_inputs)

    def run_slow_tests(self):
        config_and_inputs = self.prepare_config_and_inputs()
        self.create_and_check_model_from_pretrained(*config_and_inputs)
pytorch_pretrained_bert/tests/model_utils_test.py 0 → 100644

# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import PretrainedConfig, PreTrainedModel
from pytorch_pretrained_bert.modeling import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP


class ModelUtilsTest(unittest.TestCase):
    def test_model_from_pretrained(self):
        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            config = BertConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, PretrainedConfig)

            model = BertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, PreTrainedModel)

            config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
            model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
            self.assertEqual(model.config.output_attentions, True)
            self.assertEqual(model.config.output_hidden_states, True)
            self.assertEqual(model.config, config)


if __name__ == "__main__":
    unittest.main()
pytorch_pretrained_bert/tests/modeling_gpt2_test.py 0 → 100644

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
                                     GPT2LMHeadModel, GPT2DoubleHeadsModel)

from .model_tests_commons import (create_and_check_for_attentions,
                                  create_and_check_for_head_pruning,
                                  create_and_check_for_headmasking,
                                  create_and_check_for_hidden_states,
                                  ConfigTester, GPTModelTester)


class GPT2ModelTest(unittest.TestCase):
    def test_config(self):
        config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)
        config_tester.run_common_tests()

    def test_model(self):
        model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
                                      lm_head_model_class=GPT2LMHeadModel,
                                      double_head_model_class=GPT2DoubleHeadsModel)
        model_tester.run_common_tests(test_presents=True)

    @pytest.mark.slow
    def test_pretrained(self):
        model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
                                      lm_head_model_class=GPT2LMHeadModel,
                                      double_head_model_class=GPT2DoubleHeadsModel)
        model_tester.run_slow_tests()


if __name__ == "__main__":
    unittest.main()
pytorch_pretrained_bert/tests/modeling_openai_test.py 0 → 100644

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
                                     OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)

from .model_tests_commons import (create_and_check_for_attentions,
                                  create_and_check_for_head_pruning,
                                  create_and_check_for_headmasking,
                                  create_and_check_for_hidden_states,
                                  ConfigTester, GPTModelTester)


class OpenAIModelTest(unittest.TestCase):
    def test_config(self):
        config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37)
        config_tester.run_common_tests()

    def test_model(self):
        model_tester = GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
                                      lm_head_model_class=OpenAIGPTLMHeadModel,
                                      double_head_model_class=OpenAIGPTDoubleHeadsModel)
        model_tester.run_common_tests(test_presents=False)

    @pytest.mark.slow
    def test_pretrained(self):
        model_tester = GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
                                      lm_head_model_class=OpenAIGPTLMHeadModel,
                                      double_head_model_class=OpenAIGPTDoubleHeadsModel)
        model_tester.run_slow_tests()


if __name__ == "__main__":
    unittest.main()
tests/modeling_test.py → pytorch_pretrained_bert/tests/modeling_test.py

@@ -31,6 +31,8 @@ from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                     BertForTokenClassification, BertForMultipleChoice)
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP

+from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
+

class BertModelTest(unittest.TestCase):
    class BertModelTester(object):
...
@@ -57,7 +59,11 @@ class BertModelTest(unittest.TestCase):
                     initializer_range=0.02,
                     num_labels=3,
                     num_choices=4,
-                     scope=None):
+                     scope=None,
+                     all_model_classes=(BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                        BertForPreTraining, BertForQuestionAnswering,
+                                        BertForSequenceClassification, BertForTokenClassification),
+                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
...
@@ -80,25 +86,26 @@ class BertModelTest(unittest.TestCase):
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope
+            self.all_model_classes = all_model_classes

        def prepare_config_and_inputs(self):
-            input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
-                input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
-                token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

            sequence_labels = None
            token_labels = None
            choice_labels = None
            if self.use_labels:
-                sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
-                token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-                choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)
+                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+                choice_labels = ids_tensor([self.batch_size], self.num_choices)

            config = BertConfig(
                vocab_size_or_config_json_file=self.vocab_size,
...
@@ -120,136 +127,117 @@ class BertModelTest(unittest.TestCase):
                list(result["loss"].size()),
                [])

-        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = BertModel(config=config)
            model.eval()
            sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)

            model = BertModel(config=config, output_hidden_states=True)
            model.eval()
            _, _, all_encoder_layers = model(input_ids, token_type_ids, input_mask)

-            outputs = {
+            result = {
                "sequence_output": sequence_output,
                "pooled_output": pooled_output,
                "all_encoder_layers": all_encoder_layers,
            }
-            return outputs
-
-        def check_bert_model_output(self, result):
            self.parent.assertListEqual(
                [size for layer in result["all_encoder_layers"] for size in layer.size()],
                [self.batch_size, self.seq_length, self.hidden_size] * (self.num_hidden_layers + 1))
            self.parent.assertListEqual(
                list(result["sequence_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(
                list(result["pooled_output"].size()),
                [self.batch_size, self.hidden_size])

-        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = BertForMaskedLM(config=config)
            model.eval()
            loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
-            outputs = {
+            result = {
                "loss": loss,
                "prediction_scores": prediction_scores,
            }
-            return outputs
-
-        def check_bert_for_masked_lm_output(self, result):
            self.parent.assertListEqual(
                list(result["prediction_scores"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.check_loss_output(result)

-        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = BertForNextSentencePrediction(config=config)
            model.eval()
            loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
-            outputs = {
+            result = {
                "loss": loss,
                "seq_relationship_score": seq_relationship_score,
            }
-            return outputs
-
-        def check_bert_for_next_sequence_prediction_output(self, result):
            self.parent.assertListEqual(
                list(result["seq_relationship_score"].size()),
                [self.batch_size, 2])
            self.check_loss_output(result)

-        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = BertForPreTraining(config=config)
            model.eval()
            loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
-            outputs = {
+            result = {
                "loss": loss,
                "prediction_scores": prediction_scores,
                "seq_relationship_score": seq_relationship_score,
            }
-            return outputs
-
-        def check_bert_for_pretraining_output(self, result):
            self.parent.assertListEqual(
                list(result["prediction_scores"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(result["seq_relationship_score"].size()),
                [self.batch_size, 2])
            self.check_loss_output(result)

-        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = BertForQuestionAnswering(config=config)
            model.eval()
            loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
-            outputs = {
+            result = {
                "loss": loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
            }
-            return outputs
-
-        def check_bert_for_question_answering_output(self, result):
            self.parent.assertListEqual(
                list(result["start_logits"].size()),
                [self.batch_size, self.seq_length])
            self.parent.assertListEqual(
                list(result["end_logits"].size()),
                [self.batch_size, self.seq_length])
            self.check_loss_output(result)

-        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
+        def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            config.num_labels = self.num_labels
+            model = BertForSequenceClassification(config)
            model.eval()
            loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
-            outputs = {
+            result = {
                "loss": loss,
                "logits": logits,
            }
-            return outputs
-
-        def check_bert_for_sequence_classification_output(self, result):
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.num_labels])
            self.check_loss_output(result)

-        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            model = BertForTokenClassification(config=config, num_labels=self.num_labels)
+        def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            config.num_labels = self.num_labels
+            model = BertForTokenClassification(config=config)
            model.eval()
            loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
-            outputs = {
+            result = {
                "loss": loss,
                "logits": logits,
            }
-            return outputs
-
-        def check_bert_for_token_classification_output(self, result):
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.seq_length, self.num_labels])
            self.check_loss_output(result)

-        def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
+        def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            config.num_choices = self.num_choices
+            model = BertForMultipleChoice(config=config)
            model.eval()
            multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
            multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
...
@@ -258,148 +246,26 @@ class BertModelTest(unittest.TestCase):
                                 multiple_choice_token_type_ids,
                                 multiple_choice_input_mask,
                                 choice_labels)
-            outputs = {
+            result = {
                "loss": loss,
                "logits": logits,
            }
-            return outputs
-
-        def check_bert_for_multiple_choice(self, result):
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.num_choices])
            self.check_loss_output(result)

-        def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction, BertForPreTraining,
-                                BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification):
-                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
-                    model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
-                else:
-                    model = model_class(config=config, output_attentions=True)
-                model.eval()
-                outputs = model(input_ids, token_type_ids, input_mask)
-                attentions = outputs[-1]
-                self.parent.assertEqual(len(attentions), self.num_hidden_layers)
-                self.parent.assertListEqual(
-                    list(attentions[0].size()),
-                    [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
-
-        def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction, BertForPreTraining,
-                                BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification):
-                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
-                    model = model_class(config=config, num_labels=self.num_labels)
-                else:
-                    model = model_class(config=config)
-                model.eval()
-                head_mask = torch.ones(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
-                head_mask[0, 1:-1] = 0.0  # Mask all but the first and last heads on the first layer
-                head_mask[-1, 1:] = 0.0  # Mask all but the first head on the last layer
-                # Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
-                head_mask.requires_grad_(requires_grad=True)
-                outputs = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
-
-                # Compute some gradients
-                output = sum(t.sum() for t in outputs[0])
-                output = output.sum()
-                output.backward()
-                multihead_outputs = head_mask.grad
-
-                self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
-                # self.parent.assertListEqual(
-                #     list(multihead_outputs[0].size()),
-                #     [self.batch_size, self.num_attention_heads,
-                #      self.seq_length, self.hidden_size // self.num_attention_heads])
-                # self.parent.assertEqual(
-                #     len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
-                #     0)
-                # self.parent.assertEqual(
-                #     len(multihead_outputs[0][:, 0, :, :].nonzero()),
-                #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-                # self.parent.assertEqual(
-                #     len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
-                #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-                # self.parent.assertListEqual(
-                #     list(multihead_outputs[1].size()),
-                #     [self.batch_size, self.num_attention_heads,
-                #      self.seq_length, self.hidden_size // self.num_attention_heads])
-                # self.parent.assertEqual(
-                #     len(multihead_outputs[1].nonzero()),
-                #     multihead_outputs[1].numel())
-                # self.parent.assertListEqual(
-                #     list(multihead_outputs[-1].size()),
-                #     [self.batch_size, self.num_attention_heads,
-                #      self.seq_length, self.hidden_size // self.num_attention_heads])
-                # self.parent.assertEqual(
-                #     len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
-                #     0)
-                # self.parent.assertEqual(
-                #     len(multihead_outputs[-1][:, 0, :, :].nonzero()),
-                #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-
-        def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction, BertForPreTraining,
-                                BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification):
-                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
-                    model = model_class(config=config, num_labels=self.num_labels)
-                else:
-                    model = model_class(config=config)
-                model.eval()
-                bert_model = model if isinstance(model, BertModel) else model.bert
-                heads_to_prune = {0: list(range(1, self.num_attention_heads)),
-                                  -1: [0]}
-                bert_model.prune_heads(heads_to_prune)
-                outputs = model(input_ids, token_type_ids, input_mask)
-
-                # output = sum(t.sum() for t in outputs[0])
-                # output = output.sum()
-                # output.backward()
-                # multihead_outputs = bert_model.get_multihead_outputs()
-
-                # self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
-                # self.parent.assertListEqual(
-                #     list(multihead_outputs[0].size()),
-                #     [self.batch_size, 1,
-                #      self.seq_length, self.hidden_size // self.num_attention_heads])
-                # self.parent.assertListEqual(
-                #     list(multihead_outputs[1].size()),
-                #     [self.batch_size, self.num_attention_heads,
-                #      self.seq_length, self.hidden_size // self.num_attention_heads])
-                # self.parent.assertListEqual(
-                #     list(multihead_outputs[-1].size()),
-                #     [self.batch_size, self.num_attention_heads-1,
-                #      self.seq_length, self.hidden_size // self.num_attention_heads])
+        def create_and_check_bert_commons(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
+            create_and_check_commons(self, config, inputs_dict)

    def test_default(self):
        self.run_tester(BertModelTest.BertModelTester(self))

-    def test_config_to_json_string(self):
-        config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
-        obj = json.loads(config.to_json_string())
-        self.assertEqual(obj["vocab_size"], 99)
-        self.assertEqual(obj["hidden_size"], 37)
-
-    def test_config_to_json_file(self):
-        config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
-        json_file_path = "/tmp/config.json"
-        config_first.to_json_file(json_file_path)
-        config_second = BertConfig.from_json_file(json_file_path)
-        os.remove(json_file_path)
-        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+    def test_config(self):
+        config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
+        config_tester.run_common_tests()

    @pytest.mark.slow
    def test_model_from_pretrained(self):
...
@@ -411,57 +277,31 @@ class BertModelTest(unittest.TestCase):
    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()
-        output_result = tester.create_bert_model(*config_and_inputs)
-        tester.check_bert_model_output(output_result)
-
-        output_result = tester.create_bert_for_masked_lm(*config_and_inputs)
-        tester.check_bert_for_masked_lm_output(output_result)
-        tester.check_loss_output(output_result)
-
-        output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs)
-        tester.check_bert_for_next_sequence_prediction_output(output_result)
-        tester.check_loss_output(output_result)
-
-        output_result = tester.create_bert_for_pretraining(*config_and_inputs)
-        tester.check_bert_for_pretraining_output(output_result)
-        tester.check_loss_output(output_result)
-
-        output_result = tester.create_bert_for_question_answering(*config_and_inputs)
-        tester.check_bert_for_question_answering_output(output_result)
-        tester.check_loss_output(output_result)
-
-        output_result = tester.create_bert_for_sequence_classification(*config_and_inputs)
-        tester.check_bert_for_sequence_classification_output(output_result)
-        tester.check_loss_output(output_result)
-
-        output_result = tester.create_bert_for_token_classification(*config_and_inputs)
-        tester.check_bert_for_token_classification_output(output_result)
-        tester.check_loss_output(output_result)
-
-        output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
-        tester.check_bert_for_multiple_choice(output_result)
-        tester.check_loss_output(output_result)
-
-        tester.create_and_check_bert_for_attentions(*config_and_inputs)
-        tester.create_and_check_bert_for_headmasking(*config_and_inputs)
-        tester.create_and_check_bert_for_head_pruning(*config_and_inputs)
-
-    @classmethod
-    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
-        """Creates a random int32 tensor of the shape within the vocab size."""
-        if rng is None:
-            rng = random.Random()
-
-        total_dims = 1
-        for dim in shape:
-            total_dims *= dim
-
-        values = []
-        for _ in range(total_dims):
-            values.append(rng.randint(0, vocab_size - 1))
-
-        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
+        tester.create_and_check_bert_model(*config_and_inputs)
+
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
+
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
+
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
+
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_pretraining(*config_and_inputs)
+
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_question_answering(*config_and_inputs)
+
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
+
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_token_classification(*config_and_inputs)
+
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_commons(*config_and_inputs)

if __name__ == "__main__":
    unittest.main()
tests/modeling_transfo_xl_test.py → pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py

...
@@ -28,6 +28,8 @@ import torch

from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP

+from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
+

class TransfoXLModelTest(unittest.TestCase):
    class TransfoXLModelTester(object):
...
@@ -41,54 +43,58 @@ class TransfoXLModelTest(unittest.TestCase):
                     use_labels=True,
                     vocab_size=99,
                     cutoffs=[10, 50, 80],
-                     d_model=32,
+                     hidden_size=32,
                     d_embed=32,
-                     n_head=4,
+                     num_attention_heads=4,
                     d_head=8,
                     d_inner=128,
                     div_val=2,
-                     n_layer=5,
+                     num_hidden_layers=5,
                     scope=None,
-                     seed=1):
+                     seed=1,
+                     all_model_classes=(TransfoXLModel, TransfoXLLMHeadModel),
+                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            self.key_len = seq_length + mem_len
            self.clamp_len = clamp_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
-            self.d_model = d_model
+            self.hidden_size = hidden_size
            self.d_embed = d_embed
-            self.n_head = n_head
+            self.num_attention_heads = num_attention_heads
            self.d_head = d_head
            self.d_inner = d_inner
            self.div_val = div_val
-            self.n_layer = n_layer
+            self.num_hidden_layers = num_hidden_layers
            self.scope = scope
            self.seed = seed
+            self.all_model_classes = all_model_classes

        def prepare_config_and_inputs(self):
-            input_ids_1 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-            input_ids_2 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            lm_labels = None
            if self.use_labels:
-                lm_labels = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            config = TransfoXLConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                cutoffs=self.cutoffs,
-                d_model=self.d_model,
+                d_model=self.hidden_size,
                d_embed=self.d_embed,
-                n_head=self.n_head,
+                n_head=self.num_attention_heads,
                d_head=self.d_head,
                d_inner=self.d_inner,
                div_val=self.div_val,
-                n_layer=self.n_layer)
+                n_layer=self.num_hidden_layers)

            return (config, input_ids_1, input_ids_2, lm_labels)
...
@@ -113,37 +119,34 @@ class TransfoXLModelTest(unittest.TestCase):
def
check_transfo_xl_model_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
d_model
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
d_model
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
n
um_hidden
_layer
s
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
n
um_hidden
_layer
s
)
def
create_transfo_xl_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
lm_labels
):
model
=
TransfoXLLMHeadModel
(
config
)
model
.
eval
()
loss_1
,
mems_1a
=
model
(
input_ids_1
,
labels
=
lm_labels
)
lm_logits_1
,
mems_1b
=
model
(
input_ids_1
)
loss_2
,
mems_2a
=
model
(
input_ids_2
,
labels
=
lm_labels
,
mems
=
mems_1a
)
lm_logits_2
,
mems_2b
=
model
(
input_ids_2
,
mems
=
mems_1b
)
lm_logits_1
,
mems_1
=
model
(
input_ids_1
)
loss_1
,
_
,
mems_1
=
model
(
input_ids_1
,
labels
=
lm_labels
)
lm_logits_2
,
mems_2
=
model
(
input_ids_2
,
mems
=
mems_1
)
loss_2
,
_
,
mems_2
=
model
(
input_ids_2
,
labels
=
lm_labels
,
mems
=
mems_1
)
outputs
=
{
"loss_1"
:
loss_1
,
"mems_1
a
"
:
mems_1
a
,
"mems_1"
:
mems_1
,
"lm_logits_1"
:
lm_logits_1
,
"mems_1b"
:
mems_1b
,
"loss_2"
:
loss_2
,
"mems_2
a
"
:
mems_2
a
,
"mems_2"
:
mems_2
,
"lm_logits_2"
:
lm_logits_2
,
"mems_2b"
:
mems_2b
,
}
return
outputs
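The rewritten create_transfo_xl_lm_head exercises a new return convention: two values (lm_logits, mems) without labels, three values when labels are supplied, of which the test keeps only the loss and the new memories. A small helper sketch of that unpacking, assuming the forward signature used above:

import torch


def run_lm_head_once(model, input_ids, labels=None, mems=None):
    # Mirrors the unpacking above: (lm_logits, new_mems) without labels,
    # (loss, <discarded>, new_mems) when labels are supplied.
    model.eval()
    with torch.no_grad():
        if labels is None:
            lm_logits, new_mems = model(input_ids, mems=mems)
            return lm_logits, new_mems
        loss, _, new_mems = model(input_ids, labels=labels, mems=mems)
        return loss, new_mems

Chaining two calls and feeding the returned mems back in, as the test does, checks that memory reuse keeps shapes stable across segments.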
...
...
@@ -155,14 +158,8 @@ class TransfoXLModelTest(unittest.TestCase):
                 list(result["lm_logits_1"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_1a"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_1b"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))
+                list(list(mem.size()) for mem in result["mems_1"]),
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
             self.parent.assertListEqual(
                 list(result["loss_2"].size()),
...
...
@@ -171,31 +168,19 @@ class TransfoXLModelTest(unittest.TestCase):
                 list(result["lm_logits_2"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_2a"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_2b"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))
+                list(list(mem.size()) for mem in result["mems_2"]),
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
+
+        def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
+            inputs_dict = {'input_ids': input_ids_1}
+            create_and_check_commons(self, config, inputs_dict)
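create_and_check_commons comes from the new model_tests_commons.py, which this hunk does not show. Purely as an illustration of what a checker with this call signature could do with the tester's all_model_classes (the function name below is a hypothetical stand-in, not the shipped helper):

import torch


def check_all_model_classes(tester, config, inputs_dict):
    # Hypothetical sketch: instantiate every class the tester declares and run
    # one forward pass on the shared inputs_dict.
    for model_class in tester.all_model_classes:
        model = model_class(config)
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs_dict)
        tester.parent.assertIsNotNone(outputs)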
     def test_default(self):
         self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))

-    def test_config_to_json_string(self):
-        config = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
-        obj = json.loads(config.to_json_string())
-        self.assertEqual(obj["n_token"], 96)
-        self.assertEqual(obj["d_embed"], 37)
-
-    def test_config_to_json_file(self):
-        config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
-        json_file_path = "/tmp/config.json"
-        config_first.to_json_file(json_file_path)
-        config_second = TransfoXLConfig.from_json_file(json_file_path)
-        os.remove(json_file_path)
-        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+    def test_config(self):
+        config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
+        config_tester.run_common_tests()
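ConfigTester is likewise provided by model_tests_commons.py and is not shown here. As a purely illustrative stand-in, a tester with the call signature used above could bundle the two JSON round-trip checks that this hunk deletes:

import json
import os
import tempfile


class ConfigTesterSketch(object):  # hypothetical stand-in, not the shipped class
    def __init__(self, parent, config_class=None, **kwargs):
        self.parent = parent
        self.config_class = config_class
        self.inputs_dict = kwargs

    def check_config_to_json_string(self):
        config = self.config_class(vocab_size_or_config_json_file=96, **self.inputs_dict)
        obj = json.loads(config.to_json_string())
        for key, value in self.inputs_dict.items():
            self.parent.assertEqual(obj[key], value)

    def check_config_to_json_file(self):
        config_first = self.config_class(vocab_size_or_config_json_file=96, **self.inputs_dict)
        json_file_path = os.path.join(tempfile.gettempdir(), "config.json")
        config_first.to_json_file(json_file_path)
        config_second = self.config_class.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def run_common_tests(self):
        self.check_config_to_json_string()
        self.check_config_to_json_file()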
     @pytest.mark.slow
     def test_model_from_pretrained(self):
...
...
@@ -209,28 +194,18 @@ class TransfoXLModelTest(unittest.TestCase):
         config_and_inputs = tester.prepare_config_and_inputs()

         tester.set_seed()
         config_and_inputs = tester.prepare_config_and_inputs()
         output_result = tester.create_transfo_xl_model(*config_and_inputs)
         tester.check_transfo_xl_model_output(output_result)

         tester.set_seed()
         config_and_inputs = tester.prepare_config_and_inputs()
         output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
         tester.check_transfo_xl_lm_head_output(output_result)

-    @classmethod
-    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
-        """Creates a random int32 tensor of the shape within the vocab size."""
-        if rng is None:
-            rng = random.Random()
-
-        total_dims = 1
-        for dim in shape:
-            total_dims *= dim
-
-        values = []
-        for _ in range(total_dims):
-            values.append(rng.randint(0, vocab_size - 1))
-
-        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
+        tester.set_seed()
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_transfo_xl_commons(*config_and_inputs)

 if __name__ == "__main__":
...
...
tests/modeling_xlnet_test.py → pytorch_pretrained_bert/tests/modeling_xlnet_test.py
View file @ 1484d67d
...
...
@@ -25,9 +25,11 @@ import pytest
 import torch

-from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel)
+from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel,
+                                     XLNetForSequenceClassification, XLNetForQuestionAnswering)
 from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
+from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor

 class XLNetModelTest(unittest.TestCase):
     class XLNetModelTester(object):
...
...
@@ -42,43 +44,48 @@ class XLNetModelTest(unittest.TestCase):
                      use_labels=True,
                      vocab_size=99,
                      cutoffs=[10, 50, 80],
-                     d_model=32,
-                     n_head=4,
+                     hidden_size=32,
+                     num_attention_heads=4,
                      d_inner=128,
-                     n_layer=5,
+                     num_hidden_layers=5,
                      max_position_embeddings=10,
                      untie_r=True,
                      bi_data=False,
                      same_length=False,
                      seed=1,
-                     type_vocab_size=2):
+                     type_vocab_size=2,
+                     all_model_classes=(XLNetModel, XLNetLMHeadModel,
+                                        XLNetForSequenceClassification, XLNetForQuestionAnswering),
+                     ):
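all_model_classes now also covers the two task heads imported at the top of the file. A hedged sketch of instantiating each of them from a small config shaped like the tester's (values are illustrative; mem_len in particular is not shown in this hunk):

from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel,
                                     XLNetForSequenceClassification, XLNetForQuestionAnswering)

# Illustrative small config mirroring the kwargs the tester passes below.
config = XLNetConfig(vocab_size_or_config_json_file=99,
                     d_model=32,
                     n_head=4,
                     d_inner=128,
                     n_layer=5,
                     untie_r=True,
                     max_position_embeddings=10,
                     mem_len=10)  # placeholder value

for model_class in (XLNetModel, XLNetLMHeadModel,
                    XLNetForSequenceClassification, XLNetForQuestionAnswering):
    model = model_class(config)  # every class in all_model_classes takes the same config
    model.eval()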
             self.parent = parent
             self.batch_size = batch_size
             self.seq_length = seq_length
             self.mem_len = mem_len
             # self.key_len = seq_length + mem_len
             self.clamp_len = clamp_len
             self.reuse_len = reuse_len
             self.is_training = is_training
             self.use_labels = use_labels
             self.vocab_size = vocab_size
             self.cutoffs = cutoffs
-            self.d_model = d_model
-            self.n_head = n_head
+            self.hidden_size = hidden_size
+            self.num_attention_heads = num_attention_heads
             self.d_inner = d_inner
-            self.n_layer = n_layer
+            self.num_hidden_layers = num_hidden_layers
             self.max_position_embeddings = max_position_embeddings
             self.bi_data = bi_data
             self.untie_r = untie_r
             self.same_length = same_length
             self.seed = seed
             self.type_vocab_size = type_vocab_size
+            self.all_model_classes = all_model_classes

         def prepare_config_and_inputs(self):
-            input_ids_1 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-            input_ids_2 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-            segment_ids = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

-            input_ids_q = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
+            input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
             perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
             perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
             target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
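The permutation mask and target mapping built here set up XLNet's "predict the last token" case: a zero entry in perm_mask lets one position attend to another, a one blocks it. A tiny worked example (the target_mapping assignment is illustrative; the corresponding line is elided from this hunk):

import torch

batch_size, seq_length = 2, 4
perm_mask = torch.zeros(batch_size, seq_length + 1, seq_length + 1, dtype=torch.float)
perm_mask[:, :, -1] = 1.0  # no position may attend to the final token

target_mapping = torch.zeros(batch_size, 1, seq_length + 1, dtype=torch.float)
target_mapping[:, 0, -1] = 1.0  # the single prediction target is the final position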
...
...
@@ -89,8 +96,8 @@ class XLNetModelTest(unittest.TestCase):
             # token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
             # input_mask: float32 Tensor in shape [bsz, len], the input mask.
             #   0 for real tokens and 1 for padding.
-            # mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory
-            #   from previous batches. The length of the list equals n_layer.
+            # mems: a list of float32 Tensors in shape [bsz, mem_len, hidden_size], memory
+            #   from previous batches. The length of the list equals num_hidden_layers.
             #   If None, no memory is used.
             # perm_mask: float32 Tensor in shape [bsz, len, len].
             #   If perm_mask[k, i, j] = 0, i attend to j in batch k;
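For reference, a dummy memory list matching the shapes this test file asserts for returned mems. Note the comment above gives [bsz, mem_len, hidden_size], while the assertions further down in this file check [mem_len, batch_size, hidden_size] per layer:

import torch

mem_len, batch_size, hidden_size, num_hidden_layers = 10, 13, 32, 5
# One time-major tensor per layer, as the shape assertions below expect.
mems = [torch.zeros(mem_len, batch_size, hidden_size)
        for _ in range(num_hidden_layers)]
assert len(mems) == num_hidden_layers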
...
...
@@ -108,14 +115,14 @@ class XLNetModelTest(unittest.TestCase):
             lm_labels = None
             if self.use_labels:
-                lm_labels = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

             config = XLNetConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
-                d_model=self.d_model,
-                n_head=self.n_head,
+                d_model=self.hidden_size,
+                n_head=self.num_attention_heads,
                 d_inner=self.d_inner,
-                n_layer=self.n_layer,
+                n_layer=self.num_hidden_layers,
                 untie_r=self.untie_r,
                 max_position_embeddings=self.max_position_embeddings,
                 mem_len=self.mem_len,
...
...
@@ -159,7 +166,7 @@ class XLNetModelTest(unittest.TestCase):
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_1"]),
-                [[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
+                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
             self.parent.assertListEqual(
                 list(result["loss_2"].size()),
...
...
@@ -169,24 +176,18 @@ class XLNetModelTest(unittest.TestCase):
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_2"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
+
+        def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q,
+                                           perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
+            inputs_dict = {'input_ids': input_ids_1}
+            create_and_check_commons(self, config, inputs_dict)
     def test_default(self):
         self.run_tester(XLNetModelTest.XLNetModelTester(self))

-    def test_config_to_json_string(self):
-        config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16 * 4)
-        obj = json.loads(config.to_json_string())
-        self.assertEqual(obj["n_token"], 96)
-        self.assertEqual(obj["d_model"], 16 * 4)
-
-    def test_config_to_json_file(self):
-        config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16 * 4)
-        json_file_path = "/tmp/config.json"
-        config_first.to_json_file(json_file_path)
-        config_second = XLNetConfig.from_json_file(json_file_path)
-        os.remove(json_file_path)
-        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+    def test_config(self):
+        config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
+        config_tester.run_common_tests()
     @pytest.mark.slow
     def test_model_from_pretrained(self):
...
...
@@ -197,27 +198,14 @@ class XLNetModelTest(unittest.TestCase):
         self.assertIsNotNone(model)

     def run_tester(self, tester):
         config_and_inputs = tester.prepare_config_and_inputs()

         tester.set_seed()
         config_and_inputs = tester.prepare_config_and_inputs()
         output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
         tester.check_transfo_xl_lm_head_output(output_result)

-    @classmethod
-    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
-        """Creates a random int32 tensor of the shape within the vocab size."""
-        if rng is None:
-            rng = random.Random()
-
-        total_dims = 1
-        for dim in shape:
-            total_dims *= dim
-
-        values = []
-        for _ in range(total_dims):
-            values.append(rng.randint(0, vocab_size - 1))
-
-        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
+        tester.set_seed()
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_xlnet_commons(*config_and_inputs)

     @classmethod
     def mask_tensor(cls, shape, vocab_size, rng=None, name=None):
...
...
tests/optimization_test.py → pytorch_pretrained_bert/tests/optimization_test.py
View file @ 1484d67d
File moved