chenpangpang / transformers · Commits

Commit 1484d67d, authored Jul 02, 2019 by thomwolf
Parent: 4f8b5f68

    [LARGE] updating all tests and API

Changes: 28. Showing 20 changed files with 1479 additions and 1347 deletions.
pytorch_pretrained_bert/model_utils.py  +29 -47
pytorch_pretrained_bert/modeling.py  +47 -65
pytorch_pretrained_bert/modeling_gpt2.py  +90 -212
pytorch_pretrained_bert/modeling_openai.py  +78 -82
pytorch_pretrained_bert/modeling_transfo_xl.py  +147 -57
pytorch_pretrained_bert/modeling_xlm.py  +95 -688
pytorch_pretrained_bert/modeling_xlnet.py  +64 -76
pytorch_pretrained_bert/tests/__init__.py  +0 -0
pytorch_pretrained_bert/tests/conftest.py  +0 -0
pytorch_pretrained_bert/tests/fixtures/input.txt  +0 -0
pytorch_pretrained_bert/tests/fixtures/sample_text.txt  +0 -0
pytorch_pretrained_bert/tests/fixtures/test_sentencepiece.model  +0 -0
pytorch_pretrained_bert/tests/model_tests_commons.py  +379 -0
pytorch_pretrained_bert/tests/model_utils_test.py  +50 -0
pytorch_pretrained_bert/tests/modeling_gpt2_test.py  +55 -0
pytorch_pretrained_bert/tests/modeling_openai_test.py  +55 -0
pytorch_pretrained_bert/tests/modeling_test.py  +307 -0
pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py  +45 -70
pytorch_pretrained_bert/tests/modeling_xlnet_test.py  +38 -50
pytorch_pretrained_bert/tests/optimization_test.py  +0 -0
pytorch_pretrained_bert/model_utils.py

@@ -41,6 +41,12 @@ class PretrainedConfig(object):
     """
     pretrained_config_archive_map = {}
 
+    def __init__(self, **kwargs):
+        self.finetuning_task = kwargs.pop('finetuning_task', None)
+        self.num_labels = kwargs.pop('num_labels', 2)
+        self.output_attentions = kwargs.pop('output_attentions', False)
+        self.output_hidden_states = kwargs.pop('output_hidden_states', False)
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         """
@@ -114,6 +120,9 @@ class PretrainedConfig(object):
             text = reader.read()
         return cls.from_dict(json.loads(text))
 
+    def __eq__(self, other):
+        return self.__dict__ == other.__dict__
+
     def __repr__(self):
         return str(self.to_json_string())
@@ -133,12 +142,11 @@ class PretrainedConfig(object):
 class PreTrainedModel(nn.Module):
-    """ An abstract class to handle weights initialization and
+    """ An abstract class to handle storing model config and
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = PretrainedConfig
     pretrained_model_archive_map = {}
-    pretrained_config_archive_map = {}
     load_tf_weights = lambda model, config, path: None
     base_model_prefix = ""
@@ -151,8 +159,16 @@ class PreTrainedModel(nn.Module):
                 "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                     self.__class__.__name__, self.__class__.__name__
                 ))
+        # Save config in model
         self.config = config
 
+    def prune_heads(self, heads_to_prune):
+        """ Prunes heads of the base model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        """
+        model_to_prune = getattr(self, self.base_model_prefix, self)  # get the base model if needed
+        model_to_prune._prune_heads(heads_to_prune)
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         """
@@ -175,24 +191,22 @@ class PreTrainedModel(nn.Module):
             *inputs, **kwargs: additional input for the specific XLNet class
                 (ex: num_labels for XLNetForSequenceClassification)
         """
-        state_dict = kwargs.get('state_dict', None)
-        kwargs.pop('state_dict', None)
-        cache_dir = kwargs.get('cache_dir', None)
-        kwargs.pop('cache_dir', None)
-        from_tf = kwargs.get('from_tf', False)
-        kwargs.pop('from_tf', None)
+        state_dict = kwargs.pop('state_dict', None)
+        cache_dir = kwargs.pop('cache_dir', None)
+        from_tf = kwargs.pop('from_tf', None)
+
+        # Load config
+        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+
+        # Load model
         if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
             archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
-            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
         else:
             if from_tf:
                 # Directly load from a TensorFlow checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
-                config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
             else:
                 archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-                config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
@@ -210,47 +224,15 @@ class PreTrainedModel(nn.Module):
                     ', '.join(cls.pretrained_model_archive_map.keys()),
                     archive_file))
             return None
-        try:
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-                        config_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(cls.pretrained_config_archive_map.keys()),
-                        config_file))
-            return None
-        if resolved_archive_file == archive_file and resolved_config_file == config_file:
+        if resolved_archive_file == archive_file:
             logger.info("loading weights file {}".format(archive_file))
-            logger.info("loading configuration file {}".format(config_file))
         else:
             logger.info("loading weights file {} from cache at {}".format(
                 archive_file, resolved_archive_file))
-            logger.info("loading configuration file {} from cache at {}".format(
-                config_file, resolved_config_file))
-        # Load config
-        config = cls.config_class.from_json_file(resolved_config_file)
-        # Update config with kwargs if needed
-        to_remove = []
-        for key, value in kwargs.items():
-            if hasattr(config, key):
-                setattr(config, key, value)
-                to_remove.append(key)
-        for key in to_remove:
-            kwargs.pop(key, None)
-        logger.info("Model config {}".format(config))
 
         # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
+        model = cls(config)
         if state_dict is None and not from_tf:
             state_dict = torch.load(resolved_archive_file, map_location='cpu')
         if from_tf:
@@ -275,7 +257,7 @@ class PreTrainedModel(nn.Module):
                 if child is not None:
                     load(child, prefix + name + '.')
 
-        # Be able to load base models as well as derived models (with heads)
+        # Make sure we are able to load base models as well as derived models (with heads)
         start_prefix = ''
         model_to_load = model
         if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
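The net effect of the model_utils.py changes: loading options (state_dict, cache_dir, from_tf) are popped inside from_pretrained, and every other keyword argument is now routed into the configuration object, whose base __init__ pops the shared attributes (finetuning_task, num_labels, output_attentions, output_hidden_states). A minimal sketch of the intended behaviour based on the diff above; the MyConfig subclass is illustrative only and not part of this commit:

    from pytorch_pretrained_bert.model_utils import PretrainedConfig

    class MyConfig(PretrainedConfig):
        """Hypothetical config, mirroring how BertConfig/GPT2Config now call the base __init__."""
        def __init__(self, hidden_size=768, **kwargs):
            super(MyConfig, self).__init__(**kwargs)  # pops finetuning_task, num_labels,
            self.hidden_size = hidden_size            # output_attentions, output_hidden_states

    config = MyConfig(hidden_size=1024, num_labels=3, output_attentions=True)
    assert config.num_labels == 3
    assert config.output_attentions is True
    assert config.output_hidden_states is False       # default from the base class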
pytorch_pretrained_bert/modeling.py

@@ -155,7 +155,7 @@ class BertConfig(PretrainedConfig):
     pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(self,
-                 vocab_size_or_config_json_file,
+                 vocab_size_or_config_json_file=30522,
                  hidden_size=768,
                  num_hidden_layers=12,
                  num_attention_heads=12,
@@ -167,7 +167,7 @@ class BertConfig(PretrainedConfig):
                  type_vocab_size=2,
                  initializer_range=0.02,
                  layer_norm_eps=1e-12,
-                 finetuning_task=None):
+                 **kwargs):
         """Constructs BertConfig.
 
         Args:
@@ -192,8 +192,8 @@ class BertConfig(PretrainedConfig):
             initializer_range: The sttdev of the truncated_normal_initializer for
                 initializing all weight matrices.
             layer_norm_eps: The epsilon used by LayerNorm.
-            finetuning_task: name of the glue task on which the model was fine-tuned if any
         """
+        super(BertConfig, self).__init__(**kwargs)
         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                         and isinstance(vocab_size_or_config_json_file, unicode)):
             with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
@@ -213,7 +213,6 @@ class BertConfig(PretrainedConfig):
             self.type_vocab_size = type_vocab_size
             self.initializer_range = initializer_range
             self.layer_norm_eps = layer_norm_eps
-            self.finetuning_task = finetuning_task
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")
@@ -270,13 +269,13 @@ class BertEmbeddings(nn.Module):
 class BertSelfAttention(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
                 "The hidden size (%d) is not a multiple of the number of attention "
                 "heads (%d)" % (config.hidden_size, config.num_attention_heads))
-        self.output_attentions = output_attentions
+        self.output_attentions = config.output_attentions
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -344,10 +343,9 @@ class BertSelfOutput(nn.Module):
 class BertAttention(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertAttention, self).__init__()
-        self.output_attentions = output_attentions
-        self.self = BertSelfAttention(config, output_attentions=output_attentions)
+        self.self = BertSelfAttention(config)
         self.output = BertSelfOutput(config)
 
     def prune_heads(self, heads):
@@ -404,10 +402,9 @@ class BertOutput(nn.Module):
 class BertLayer(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertLayer, self).__init__()
-        self.output_attentions = output_attentions
-        self.attention = BertAttention(config, output_attentions=output_attentions)
+        self.attention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
@@ -421,11 +418,11 @@ class BertLayer(nn.Module):
 class BertEncoder(nn.Module):
-    def __init__(self, config, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
         super(BertEncoder, self).__init__()
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        layer = BertLayer(config, output_attentions=output_attentions)
+        self.output_attentions = config.output_attentions
+        self.output_hidden_states = config.output_hidden_states
+        layer = BertLayer(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, attention_mask, head_mask=None):
@@ -546,9 +543,6 @@ class BertPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
 
-    def __init__(self, *inputs, **kwargs):
-        super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
-
     def init_weights(self, module):
         """ Initialize the weights.
         """
@@ -612,19 +606,19 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
         super(BertModel, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
 
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config, output_attentions=output_attentions,
-                                          output_hidden_states=output_hidden_states)
+        self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config)
 
         self.apply(self.init_weights)
 
-    def prune_heads(self, heads_to_prune):
+    def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+            See base class PreTrainedModel
         """
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
@@ -730,14 +724,12 @@ class BertForPreTraining(BertPreTrainedModel):
     masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
         super(BertForPreTraining, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                                      output_hidden_states=output_hidden_states)
+        self.bert = BertModel(config)
         self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
@@ -809,13 +801,12 @@ class BertForMaskedLM(BertPreTrainedModel):
     masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
         super(BertForMaskedLM, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
@@ -880,12 +871,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
     seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
         super(BertForNextSentencePrediction, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_weights)
@@ -954,15 +943,13 @@ class BertForSequenceClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
         super(BertForSequenceClassification, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.num_labels = config.num_labels
+        self.bert = BertModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = nn.Linear(config.hidden_size, num_labels)
+        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
         self.apply(self.init_weights)
@@ -997,7 +984,6 @@ class BertForMultipleChoice(BertPreTrainedModel):
         `config`: a BertConfig class instance with the configuration to build a new model
         `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
         `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
-        `num_choices`: the number of classes for the classifier. Default = 2.
 
     Inputs:
         `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
@@ -1030,25 +1016,23 @@ class BertForMultipleChoice(BertPreTrainedModel):
     config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
 
-    num_choices = 2
-    model = BertForMultipleChoice(config, num_choices)
+    model = BertForMultipleChoice(config)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices=2, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
         super(BertForMultipleChoice, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        self.num_choices = num_choices
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.bert = BertModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
+        """ Input shapes should be [bsz, num choices, seq length] """
+        num_choices = input_ids.shape[1]
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
@@ -1057,7 +1041,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
-        reshaped_logits = logits.view(-1, self.num_choices)
+        reshaped_logits = logits.view(-1, num_choices)
 
         outputs = [reshaped_logits] + outputs[2:]  # add hidden states and attention if they are here
@@ -1118,15 +1102,13 @@ class BertForTokenClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
         super(BertForTokenClassification, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions)
+        self.num_labels = config.num_labels
+        self.bert = BertModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = nn.Linear(config.hidden_size, num_labels)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
         self.apply(self.init_weights)
@@ -1204,12 +1186,12 @@ class BertForQuestionAnswering(BertPreTrainedModel):
     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
         super(BertForQuestionAnswering, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        self.bert = BertModel(config, output_attentions=output_attentions)
-        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+        self.num_labels = config.num_labels
+        self.bert = BertModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
         self.apply(self.init_weights)
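With these changes the BERT heads no longer take num_labels / output_attentions / output_hidden_states as constructor arguments; they read them from the BertConfig, which forwards extra keyword arguments to PretrainedConfig.__init__. A sketch of the intended call pattern after this commit (the values are illustrative):

    from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification

    # Classifier size and output flags now live on the config.
    config = BertConfig(vocab_size_or_config_json_file=32000, num_labels=3,
                        output_attentions=True, output_hidden_states=True)

    # Replaces the old BertForSequenceClassification(config, num_labels=3, output_attentions=True, ...)
    model = BertForSequenceClassification(config)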
pytorch_pretrained_bert/modeling_gpt2.py

@@ -119,7 +119,8 @@ class GPT2Config(PretrainedConfig):
                  attn_pdrop=0.1,
                  layer_norm_epsilon=1e-5,
                  initializer_range=0.02,
-                 predict_special_tokens=True):
+                 predict_special_tokens=True,
+                 **kwargs):
         """Constructs GPT2Config.
@@ -142,6 +143,8 @@ class GPT2Config(PretrainedConfig):
                 initializing all weight matrices.
             predict_special_tokens: should we predict special tokens (when the model has a LM head)
         """
+        super(GPT2Config, self).__init__(**kwargs)
+
         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                         and isinstance(vocab_size_or_config_json_file, unicode)):
             with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
@@ -174,8 +177,10 @@ class GPT2Config(PretrainedConfig):
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, nx, n_ctx, config, scale=False):
         super(Attention, self).__init__()
+        self.output_attentions = config.output_attentions
+
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
         assert n_state % config.n_head == 0
@@ -184,10 +189,6 @@ class Attention(nn.Module):
         self.split_size = n_state
         self.scale = scale
-        self.output_attentions = output_attentions
-        self.keep_multihead_output = keep_multihead_output
-        self.multihead_output = None
-
         self.c_attn = Conv1D(n_state * 3, nx)
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -224,9 +225,10 @@ class Attention(nn.Module):
         if head_mask is not None:
             w = w * head_mask
 
+        outputs = [torch.matmul(w, v)]
         if self.output_attentions:
-            return w, torch.matmul(w, v)
-        return torch.matmul(w, v)
+            outputs.append(w)
+        return outputs
 
     def merge_heads(self, x):
         x = x.permute(0, 2, 1, 3).contiguous()
@@ -253,19 +255,15 @@ class Attention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
         present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
 
-        a = self._attn(query, key, value, head_mask)
-        if self.keep_multihead_output:
-            self.multihead_output = a
-            self.multihead_output.retain_grad()
-        if self.output_attentions:
-            attentions, a = a
+        attn_outputs = self._attn(query, key, value, head_mask)
+        a = attn_outputs[0]
+
         a = self.merge_heads(a)
         a = self.c_proj(a)
         a = self.resid_dropout(a)
-        if self.output_attentions:
-            return attentions, a, present
-        return a, present
+
+        outputs = [a, present] + attn_outputs[1:]
+        return outputs  # a, present, (attentions)
 
 
 class MLP(nn.Module):
@@ -284,27 +282,24 @@ class MLP(nn.Module):
 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, n_ctx, config, scale=False):
         super(Block, self).__init__()
         nx = config.n_embd
-        self.output_attentions = output_attentions
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
-        self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
+        self.attn = Attention(nx, n_ctx, config, scale)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
 
     def forward(self, x, layer_past=None, head_mask=None):
         output_attn = self.attn(self.ln_1(x), layer_past=layer_past, head_mask=head_mask)
-        if self.output_attentions:
-            attentions, a, present = output_attn
-        else:
-            a, present = output_attn
+        a = output_attn[0]  # output_attn: a, present, (attentions)
+
         x = x + a
         m = self.mlp(self.ln_2(x))
         x = x + m
-        if self.output_attentions:
-            return attentions, x, present
-        return x, present
+
+        outputs = [x] + output_attn[1:]
+        return outputs  # x, present, (attentions)
 
 
 class GPT2LMHead(nn.Module):
@@ -342,12 +337,17 @@ class GPT2MultipleChoiceHead(nn.Module):
         nn.init.normal_(self.linear.weight, std=0.02)
         nn.init.normal_(self.linear.bias, 0)
 
-    def forward(self, hidden_states, mc_token_ids):
-        # Classification logits
-        # hidden_state (bsz, num_choices, seq_length, hidden_size)
-        # mc_token_ids (bsz, num_choices)
-        mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
-        # (bsz, num_choices, 1, hidden_size)
+    def forward(self, hidden_states, mc_token_ids=None):
+        """ Extract classification token hidden state and project it using self.linear
+            hidden_state: shape (bsz, num_choices, seq_length, hidden_size)
+            mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
+                if mc_token_ids=None we take the last token of the sequence as classification token
+        """
+        if mc_token_ids is None:
+            mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
+        else:
+            mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
+        # mc_token_ids has shape (bsz, num_choices, 1, hidden_size)
         multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
         # (bsz, num_choices, hidden_size)
         multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
@@ -362,13 +362,9 @@ class GPT2PreTrainedModel(PreTrainedModel):
     """
     config_class = GPT2Config
     pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"
 
-    def __init__(self, *inputs, **kwargs):
-        super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
-
     def init_weights(self, module):
         """ Initialize the weights.
         """
@@ -403,126 +399,9 @@ class GPT2PreTrainedModel(PreTrainedModel):
             state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
             *inputs, **kwargs: additional input for the specific GPT2 class
         """
-        # state_dict = kwargs.get('state_dict', None)
-        # kwargs.pop('state_dict', None)
-        # cache_dir = kwargs.get('cache_dir', None)
-        # kwargs.pop('cache_dir', None)
-        # from_tf = kwargs.get('from_tf', False)
-        # kwargs.pop('from_tf', None)
-        num_special_tokens = kwargs.get('num_special_tokens', None)
-        kwargs.pop('num_special_tokens', None)
# if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
# archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
# config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
# else:
# archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
# config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# # redirect to the cache, if necessary
# try:
# resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
# except EnvironmentError:
# if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
# logger.error(
# "Couldn't reach server at '{}' to download pretrained weights.".format(
# archive_file))
# else:
# logger.error(
# "Model name '{}' was not found in model name list ({}). "
# "We assumed '{}' was a path or url but couldn't find file {} "
# "at this path or url.".format(
# pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
# archive_file
# )
# )
# return None
# try:
# resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
# except EnvironmentError:
# if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
# logger.error(
# "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
# config_file))
# else:
# logger.error(
# "Model name '{}' was not found in model name list ({}). "
# "We assumed '{}' was a path or url but couldn't find file {} "
# "at this path or url.".format(
# pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
# config_file
# )
# )
# return None
# if resolved_archive_file == archive_file and resolved_config_file == config_file:
# logger.info("loading weights file {}".format(archive_file))
# logger.info("loading configuration file {}".format(config_file))
# else:
# logger.info("loading weights file {} from cache at {}".format(
# archive_file, resolved_archive_file))
# logger.info("loading configuration file {} from cache at {}".format(
# config_file, resolved_config_file))
# # Load config
# config = GPT2Config.from_json_file(resolved_config_file)
# logger.info("Model config {}".format(config))
# # Instantiate model.
# model = cls(config, *inputs, **kwargs)
# if state_dict is None and not from_tf:
# state_dict = torch.load(resolved_archive_file, map_location='cpu')
# if from_tf:
# # Directly load from a TensorFlow checkpoint (stored as NumPy array)
# return load_tf_weights_in_gpt2(model, resolved_archive_file)
# old_keys = []
# new_keys = []
# for key in state_dict.keys():
# new_key = None
# if key.endswith(".g"):
# new_key = key[:-2] + ".weight"
# elif key.endswith(".b"):
# new_key = key[:-2] + ".bias"
# elif key.endswith(".w"):
# new_key = key[:-2] + ".weight"
# if new_key:
# old_keys.append(key)
# new_keys.append(new_key)
# for old_key, new_key in zip(old_keys, new_keys):
# state_dict[new_key] = state_dict.pop(old_key)
# missing_keys = []
# unexpected_keys = []
# error_msgs = []
# # copy state_dict so _load_from_state_dict can modify it
# metadata = getattr(state_dict, "_metadata", None)
# state_dict = state_dict.copy()
# if metadata is not None:
# state_dict._metadata = metadata
# def load(module, prefix=""):
# local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
# module._load_from_state_dict(
# state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
# )
# for name, child in module._modules.items():
# if child is not None:
# load(child, prefix + name + ".")
# start_model = model
# if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
# start_model = model.transformer
# load(start_model, prefix="")
# if len(missing_keys) > 0:
# logger.info(
# "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
# )
# if len(unexpected_keys) > 0:
# logger.info(
# "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
# )
# if len(error_msgs) > 0:
# raise RuntimeError(
# "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
# )
# Add additional embeddings for special tokens if needed
# Add additional embeddings for special tokens if needed
# This step also make sure we are still sharing the output and input embeddings after loading weights
# This step also make sure we are still sharing the output and input embeddings after loading weights
...
@@ -553,8 +432,6 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -553,8 +432,6 @@ class GPT2Model(GPT2PreTrainedModel):
Params:
Params:
`config`: a GPT2Config class instance with the configuration to build a new model
`config`: a GPT2Config class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
...
@@ -591,14 +468,15 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -591,14 +468,15 @@ class GPT2Model(GPT2PreTrainedModel):
```
```
"""
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
):
super
(
GPT2Model
,
self
).
__init__
(
config
)
super
(
GPT2Model
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
output_attentions
=
config
.
output_attentions
self
.
wte
=
nn
.
Embedding
(
config
.
total_tokens_embeddings
,
config
.
n_embd
)
self
.
wte
=
nn
.
Embedding
(
config
.
total_tokens_embeddings
,
config
.
n_embd
)
self
.
wpe
=
nn
.
Embedding
(
config
.
n_positions
,
config
.
n_embd
)
self
.
wpe
=
nn
.
Embedding
(
config
.
n_positions
,
config
.
n_embd
)
self
.
drop
=
nn
.
Dropout
(
config
.
embd_pdrop
)
self
.
drop
=
nn
.
Dropout
(
config
.
embd_pdrop
)
block
=
Block
(
config
.
n_ctx
,
config
,
scale
=
True
,
output_attentions
=
output_attentions
,
block
=
Block
(
config
.
n_ctx
,
config
,
scale
=
True
)
keep_multihead_output
=
keep_multihead_output
)
self
.
h
=
nn
.
ModuleList
([
copy
.
deepcopy
(
block
)
for
_
in
range
(
config
.
n_layer
)])
self
.
h
=
nn
.
ModuleList
([
copy
.
deepcopy
(
block
)
for
_
in
range
(
config
.
n_layer
)])
self
.
ln_f
=
LayerNorm
(
config
.
n_embd
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
LayerNorm
(
config
.
n_embd
,
eps
=
config
.
layer_norm_epsilon
)
...
@@ -618,19 +496,13 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -618,19 +496,13 @@ class GPT2Model(GPT2PreTrainedModel):
# Copy word embeddings from the previous weights
# Copy word embeddings from the previous weights
self
.
wte
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
=
old_embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
self
.
wte
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
=
old_embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
def
prune_heads
(
self
,
heads_to_prune
):
def
_
prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
h
[
layer
].
attn
.
prune_heads
(
heads
)
self
.
h
[
layer
].
attn
.
prune_heads
(
heads
)
def
get_multihead_outputs
(
self
):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return
[
h
.
attn
.
multihead_output
for
h
in
self
.
h
]
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
past
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
past
=
None
,
head_mask
=
None
):
if
past
is
None
:
if
past
is
None
:
past_length
=
0
past_length
=
0
...
@@ -675,20 +547,32 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -675,20 +547,32 @@ class GPT2Model(GPT2PreTrainedModel):
all_attentions
=
[]
all_attentions
=
[]
all_hidden_states
=
[]
all_hidden_states
=
[]
for
i
,
(
block
,
layer_past
)
in
enumerate
(
zip
(
self
.
h
,
past
)):
for
i
,
(
block
,
layer_past
)
in
enumerate
(
zip
(
self
.
h
,
past
)):
all_hidden_states
.
append
(
hidden_states
.
view
(
*
output_shape
))
if
self
.
output_hidden_states
:
all_hidden_states
.
append
(
hidden_states
.
view
(
*
output_shape
))
outputs
=
block
(
hidden_states
,
layer_past
,
head_mask
[
i
])
outputs
=
block
(
hidden_states
,
layer_past
,
head_mask
[
i
])
if
self
.
output_attentions
:
hidden_states
,
present
=
outputs
[:
2
]
attentions
,
hidden_states
,
present
=
outputs
all_attentions
.
append
(
attentions
)
else
:
hidden_states
,
present
=
outputs
presents
.
append
(
present
)
presents
.
append
(
present
)
if
self
.
output_attentions
:
all_attentions
.
append
(
outputs
[
2
])
hidden_states
=
self
.
ln_f
(
hidden_states
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
all_hidden_states
.
append
(
hidden_states
.
view
(
*
output_shape
))
hidden_states
=
hidden_states
.
view
(
*
output_shape
)
# Add last hidden state
if
self
.
output_hidden_states
:
all_hidden_states
.
append
(
hidden_states
)
outputs
=
[
hidden_states
,
presents
]
if
self
.
output_hidden_states
:
outputs
.
append
(
all_hidden_states
)
if
self
.
output_attentions
:
if
self
.
output_attentions
:
return
all_attentions
,
all_hidden_states
,
presents
# let the number of heads free (-1) so we can extract attention even after head pruning
return
all_hidden_states
,
presents
attention_output_shape
=
input_shape
[:
-
1
]
+
(
-
1
,)
+
all_attentions
[
0
].
shape
[
-
2
:]
all_attentions
=
list
(
t
.
view
(
*
attention_output_shape
)
for
t
in
all_attentions
)
outputs
.
append
(
all_attentions
)
return
outputs
# last hidden state, presents, (all hidden_states), (attentions)
class
GPT2LMHeadModel
(
GPT2PreTrainedModel
):
class
GPT2LMHeadModel
(
GPT2PreTrainedModel
):
...
@@ -740,10 +624,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -740,10 +624,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
```
```
"""
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
):
super
(
GPT2LMHeadModel
,
self
).
__init__
(
config
)
super
(
GPT2LMHeadModel
,
self
).
__init__
(
config
)
self
.
transformer
=
GPT2Model
(
config
,
output_attentions
=
output_attentions
,
self
.
transformer
=
GPT2Model
(
config
)
keep_multihead_output
=
keep_multihead_output
)
self
.
lm_head
=
GPT2LMHead
(
self
.
transformer
.
wte
.
weight
,
config
)
self
.
lm_head
=
GPT2LMHead
(
self
.
transformer
.
wte
.
weight
,
config
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
...
@@ -756,14 +639,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -756,14 +639,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
wte
.
weight
,
predict_special_tokens
=
predict_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
wte
.
weight
,
predict_special_tokens
=
predict_special_tokens
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
lm_labels
=
None
,
past
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
lm_labels
=
None
,
past
=
None
,
head_mask
=
None
):
transformer_output
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
past
,
head_mask
)
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
past
,
head_mask
)
if
self
.
transformer
.
output_attentions
:
hidden_states
=
transformer_outputs
[
0
]
all_attentions
,
hidden_states
,
presents
=
transformer_output
else
:
hidden_states
,
presents
=
transformer_output
hidden_states
=
hidden_states
[
-
1
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
outputs
=
[
lm_logits
]
+
transformer_outputs
[
1
:]
if
lm_labels
is
not
None
:
if
lm_labels
is
not
None
:
# Shift so that tokens < n predict n
# Shift so that tokens < n predict n
shift_logits
=
lm_logits
[...,
:
-
1
,
:].
contiguous
()
shift_logits
=
lm_logits
[...,
:
-
1
,
:].
contiguous
()
...
@@ -772,10 +653,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -772,10 +653,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
loss
=
loss_fct
(
shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
-            return loss
+            outputs = [loss] + outputs
-        if self.transformer.output_attentions:
-            return all_attentions, lm_logits, presents
-        return lm_logits, presents
+        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)


class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
...
@@ -832,12 +712,12 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    ```
    """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config):
        super(GPT2DoubleHeadsModel, self).__init__(config)
-        self.transformer = GPT2Model(config, output_attentions=output_attentions,
-                                     keep_multihead_output=keep_multihead_output)
+        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.multiple_choice_head = GPT2MultipleChoiceHead(config)
        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
...
@@ -848,28 +728,26 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
        self.transformer.set_num_special_tokens(num_special_tokens)
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

-    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
+    def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
                position_ids=None, past=None, head_mask=None):
-        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
-        if self.transformer.output_attentions:
-            all_attentions, hidden_states, presents = transformer_output
-        else:
-            hidden_states, presents = transformer_output
-        hidden_states = hidden_states[-1]
+        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
+        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
-        losses = []
+        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        if mc_labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
+            outputs = [loss] + outputs
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
-            losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
-        if mc_labels is not None:
-            loss_fct = CrossEntropyLoss()
-            losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
-        if losses:
-            return losses
-        if self.transformer.output_attentions:
-            return all_attentions, lm_logits, mc_logits, presents
-        return lm_logits, mc_logits, presents
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                            shift_labels.view(-1))
+            outputs = [loss] + outputs
+        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
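The forward methods above now return a single list instead of switching between different tuple shapes. A minimal sketch of how that list could be unpacked for GPT2DoubleHeadsModel; the tiny config values and the import path are assumptions for illustration, not part of this diff, while the ordering of the returned items follows the comments added above.

# Sketch only, assuming GPT2Config accepts the small hyperparameters shown here.
import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2DoubleHeadsModel

config = GPT2Config(vocab_size_or_config_json_file=1000, n_positions=64,
                    n_ctx=64, n_embd=32, n_layer=2, n_head=2)
model = GPT2DoubleHeadsModel(config)
model.eval()

input_ids = torch.randint(0, 1000, (1, 2, 7))            # (bsz, num_choices, seq_len)
mc_token_ids = torch.full((1, 2), 6, dtype=torch.long)   # index of the classification token

# Without labels: outputs = [lm_logits, mc_logits, presents, ...]
outputs = model(input_ids, mc_token_ids=mc_token_ids)
lm_logits, mc_logits = outputs[0], outputs[1]

# With labels the losses are prepended: [lm_loss, mc_loss, lm_logits, mc_logits, presents, ...]
outputs = model(input_ids, mc_token_ids=mc_token_ids,
                lm_labels=input_ids, mc_labels=torch.zeros(1, dtype=torch.long))
lm_loss, mc_loss = outputs[0], outputs[1]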
pytorch_pretrained_bert/modeling_openai.py View file @ 1484d67d
...
@@ -147,7 +147,8 @@ class OpenAIGPTConfig(PretrainedConfig):
                 attn_pdrop=0.1,
                 layer_norm_epsilon=1e-5,
                 initializer_range=0.02,
-                 predict_special_tokens=True
+                 predict_special_tokens=True,
+                 **kwargs
                 ):
        """Constructs OpenAIGPTConfig.
...
@@ -172,6 +173,8 @@ class OpenAIGPTConfig(PretrainedConfig):
                initializing all weight matrices.
            predict_special_tokens: should we predict special tokens (when the model has a LM head)
        """
+        super(OpenAIGPTConfig, self).__init__(**kwargs)
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
...
@@ -205,7 +208,7 @@ class OpenAIGPTConfig(PretrainedConfig):
class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
...
@@ -215,9 +218,7 @@ class Attention(nn.Module):
        self.split_size = n_state
        self.scale = scale
-        self.output_attentions = output_attentions
-        self.keep_multihead_output = keep_multihead_output
-        self.multihead_output = None
+        self.output_attentions = config.output_attentions

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
...
@@ -256,9 +257,10 @@ class Attention(nn.Module):
        if head_mask is not None:
            w = w * head_mask

+        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
-            return w, torch.matmul(w, v)
-        return torch.matmul(w, v)
+            outputs.append(w)
+        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
...
@@ -280,19 +282,15 @@ class Attention(nn.Module):
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

-        a = self._attn(query, key, value, head_mask)
-        if self.keep_multihead_output:
-            self.multihead_output = a
-            self.multihead_output.retain_grad()
-        if self.output_attentions:
-            attentions, a = a
+        attn_outputs = self._attn(query, key, value, head_mask)
+        a = attn_outputs[0]
        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
-        if self.output_attentions:
-            return attentions, a
-        return a
+        outputs = [a] + attn_outputs[1:]
+        return outputs  # a, (attentions)


class MLP(nn.Module):
...
@@ -311,25 +309,24 @@ class MLP(nn.Module):
class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
-        self.output_attentions = output_attentions
-        self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
+        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)

    def forward(self, x, head_mask=None):
-        a = self.attn(x, head_mask=head_mask)
-        if self.output_attentions:
-            attentions, a = a
+        attn_outputs = self.attn(x, head_mask=head_mask)
+        a = attn_outputs[0]
        n = self.ln_1(x + a)
        m = self.mlp(n)
        h = self.ln_2(n + m)
-        if self.output_attentions:
-            return attentions, h
-        return h
+        outputs = [h] + attn_outputs[1:]
+        return outputs


class OpenAIGPTLMHead(nn.Module):
...
@@ -368,11 +365,16 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

-    def forward(self, hidden_states, mc_token_ids):
-        # Classification logits
-        # hidden_state (bsz, num_choices, seq_length, hidden_size)
-        # mc_token_ids (bsz, num_choices)
-        mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
+    def forward(self, hidden_states, mc_token_ids=None):
+        """ Extract classification token hidden state and project it using self.linear
+            hidden_state: hidden state of shape (bsz, num_choices, seq_length, hidden_size)
+            mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
+                if mc_token_ids=None we take the last token of the sequence as classification token
+        """
+        if mc_token_ids is None:
+            mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
+        else:
+            mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
        # (bsz, num_choices, 1, hidden_size)
        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
        # (bsz, num_choices, hidden_size)
...
@@ -388,13 +390,9 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
    """
    config_class = OpenAIGPTConfig
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_openai_gpt
    base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
-
    def init_weights(self, module):
        """ Initialize the weights.
        """
...
@@ -495,14 +493,15 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
    ```
    """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config):
        super(OpenAIGPTModel, self).__init__(config)
-        self.output_attentions = output_attentions
+        self.output_attentions = config.output_attentions
+        self.output_hidden_states = config.output_hidden_states
        self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
        self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
-                      keep_multihead_output=keep_multihead_output)
+        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.apply(self.init_weights)
...
@@ -521,19 +520,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
        # Copy word embeddings from the previous weights
        self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]

-    def prune_heads(self, heads_to_prune):
+    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

-    def get_multihead_outputs(self):
-        """ Gather all multi-head outputs.
-            Return: list (layers) of multihead module outputs with gradients
-        """
-        return [h.attn.multihead_output for h in self.h]
-
    def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
        if position_ids is None:
            # This was used when we had a single embedding matrice from position and token embeddings
...
@@ -574,19 +567,26 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
        output_shape = input_shape + (hidden_states.size(-1),)

        all_attentions = []
-        all_hidden_states = [hidden_states.view(*output_shape)]
+        all_hidden_states = []
        for i, block in enumerate(self.h):
+            if self.output_hidden_states:
+                all_hidden_states.append(hidden_states.view(*output_shape))
            outputs = block(hidden_states, head_mask[i])
-            if self.output_attentions:
-                attentions, hidden_states = outputs
-                all_attentions.append(attentions)
-            else:
-                hidden_states = outputs
-            all_hidden_states.append(hidden_states.view(*output_shape))
+            hidden_states = outputs[0]
+            if self.output_attentions:
+                all_attentions.append(outputs[1])

+        # Add last layer
+        if self.output_hidden_states:
+            all_hidden_states.append(hidden_states.view(*output_shape))

+        outputs = [hidden_states.view(*output_shape)]
+        if self.output_hidden_states:
+            outputs.append(all_hidden_states)
        if self.output_attentions:
-            return all_attentions, all_hidden_states
-        return all_hidden_states
+            outputs.append(all_attentions)
+        return outputs  # last hidden state, (all hidden states), (all attentions)


class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -650,10 +650,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
    ```
    """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config):
        super(OpenAIGPTLMHeadModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
-                                          keep_multihead_output=keep_multihead_output)
+        self.transformer = OpenAIGPTModel(config)
        self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
        self.apply(self.init_weights)
...
@@ -666,12 +665,11 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
        self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
-        hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
-        if self.transformer.output_attentions:
-            all_attentions, hidden_states = hidden_states
-        hidden_states = hidden_states[-1]
+        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
+        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

+        outputs = [lm_logits] + transformer_outputs[1:]
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
...
@@ -680,10 +678,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
-            return loss
-        if self.transformer.output_attentions:
-            return all_attentions, lm_logits
-        return lm_logits
+            outputs = [loss] + outputs
+        return outputs  # (loss), lm_logits, (all hidden states), (all attentions)


class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -752,10 +749,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
    ```
    """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config):
        super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
-                                          keep_multihead_output=keep_multihead_output)
+        self.transformer = OpenAIGPTModel(config)
        self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
        self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
        self.apply(self.init_weights)
...
@@ -768,26 +764,26 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
        self.transformer.set_num_special_tokens(num_special_tokens)
        self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)

-    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
+    def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
                position_ids=None, head_mask=None):
-        hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
-        if self.transformer.output_attentions:
-            all_attentions, hidden_states = hidden_states
-        hidden_states = hidden_states[-1]
+        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
+        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
-        losses = []
+        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        if mc_labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
+            outputs = [loss] + outputs
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
-            losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
-        if mc_labels is not None:
-            loss_fct = CrossEntropyLoss()
-            losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
-        if losses:
-            return losses
-        if self.transformer.output_attentions:
-            return all_attentions, lm_logits, mc_logits
-        return lm_logits, mc_logits
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                            shift_labels.view(-1))
+            outputs = [loss] + outputs
+        return outputs  # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)
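Since OpenAIGPTConfig now forwards **kwargs to the base config class, output_attentions and output_hidden_states travel on the config rather than as model constructor arguments. A minimal sketch under that assumption; the small config values and the import path are illustrative, not part of this diff.

# Sketch only; a randomly initialized tiny model is enough to show the output layout.
import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, OpenAIGPTModel

config = OpenAIGPTConfig(vocab_size_or_config_json_file=1000, n_positions=64,
                         n_ctx=64, n_embd=32, n_layer=2, n_head=2,
                         output_attentions=True, output_hidden_states=True)
model = OpenAIGPTModel(config)
model.eval()

input_ids = torch.randint(0, 1000, (2, 12))
outputs = model(input_ids)

last_hidden = outputs[0]        # (bsz, seq_len, n_embd)
all_hidden_states = outputs[1]  # hidden states collected before each block, plus the final one
all_attentions = outputs[2]     # one attention tensor per layer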
pytorch_pretrained_bert/modeling_transfo_xl.py View file @ 1484d67d
...
@@ -209,7 +209,8 @@ class TransfoXLConfig(PretrainedConfig):
                 init="normal",
                 init_range=0.01,
                 proj_init_std=0.01,
-                 init_std=0.02):
+                 init_std=0.02,
+                 **kwargs):
        """Constructs TransfoXLConfig.

        Args:
...
@@ -244,6 +245,8 @@ class TransfoXLConfig(PretrainedConfig):
            proj_init_std: parameters initialized by N(0, init_std)
            init_std: parameters initialized by N(0, init_std)
        """
+        super(TransfoXLConfig, self).__init__(**kwargs)
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
...
@@ -287,6 +290,7 @@ class TransfoXLConfig(PretrainedConfig):
                             "or the path to a pretrained model config file (str)")


class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
...
@@ -306,6 +310,7 @@ class PositionalEmbedding(nn.Module):
            return pos_emb[:, None, :]


class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()
...
@@ -341,11 +346,14 @@ class PositionwiseFF(nn.Module):
        return output


class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
-                 pre_lnorm=False, r_r_bias=None, r_w_bias=None):
+                 pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(MultiHeadAttn, self).__init__()

+        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
...
@@ -371,7 +379,7 @@ class MultiHeadAttn(nn.Module):
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

-    def forward(self, h, attn_mask=None, mems=None):
+    def forward(self, h, attn_mask=None, mems=None, head_mask=None):
        ##### multihead attention
        # [hlen x bsz x n_head x d_head]
...
@@ -404,6 +412,10 @@ class MultiHeadAttn(nn.Module):
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_prob = attn_prob * head_mask
+
        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
...
@@ -415,19 +427,23 @@ class MultiHeadAttn(nn.Module):
        if self.pre_lnorm:
            ##### residual connection
-            output = h + attn_out
+            outputs = [h + attn_out]
        else:
            ##### residual connection + layer normalization
-            output = self.layer_norm(h + attn_out)
+            outputs = [self.layer_norm(h + attn_out)]

-        return output
+        if self.output_attentions:
+            outputs.append(attn_prob)
+        return outputs


class RelMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                 r_r_bias=None, r_w_bias=None):
+                 r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(RelMultiHeadAttn, self).__init__()

+        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
...
@@ -506,7 +522,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

-    def forward(self, w, r, attn_mask=None, mems=None):
+    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
...
@@ -561,6 +577,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_prob = attn_prob * head_mask
+
        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
...
@@ -574,18 +594,21 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
        if self.pre_lnorm:
            ##### residual connection
-            output = w + attn_out
+            outputs = [w + attn_out]
        else:
            ##### residual connection + layer normalization
-            output = self.layer_norm(w + attn_out)
+            outputs = [self.layer_norm(w + attn_out)]

-        return output
+        if self.output_attentions:
+            outputs.append(attn_prob)
+        return outputs


class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

-    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
+    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
        # r_emb: [klen, n_head, d_head], used for term B
        # r_w_bias: [n_head, d_head], used for term C
        # r_bias: [klen, n_head], used for term D
...
@@ -646,6 +669,9 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

+        if head_mask is not None:
+            attn_prob = attn_prob * head_mask
+
        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
...
@@ -659,12 +685,17 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
        if self.pre_lnorm:
            ##### residual connection
-            output = w + attn_out
+            outputs = [w + attn_out]
        else:
            ##### residual connection + layer normalization
-            output = self.layer_norm(w + attn_out)
+            outputs = [self.layer_norm(w + attn_out)]

-        return output
+        if self.output_attentions:
+            outputs.append(attn_prob)
+        return outputs


class DecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
...
@@ -674,13 +705,15 @@ class DecoderLayer(nn.Module):
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

-    def forward(self, dec_inp, dec_attn_mask=None, mems=None):
+    def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):

-        output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
-                               mems=mems)
-        output = self.pos_ff(output)
+        attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
+                                     mems=mems, head_mask=head_mask)
+        ff_output = self.pos_ff(attn_outputs[0])

-        return output
+        outputs = [ff_output] + attn_outputs[1:]
+        return outputs


class RelLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...
@@ -692,14 +725,16 @@ class RelLearnableDecoderLayer(nn.Module):
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

-    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
+    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):

-        output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
-                               attn_mask=dec_attn_mask,
-                               mems=mems)
-        output = self.pos_ff(output)
+        attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
+                                     attn_mask=dec_attn_mask,
+                                     mems=mems, head_mask=head_mask)
+        ff_output = self.pos_ff(attn_outputs[0])

-        return output
+        outputs = [ff_output] + attn_outputs[1:]
+        return outputs


class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...
@@ -711,14 +746,17 @@ class RelPartialLearnableDecoderLayer(nn.Module):
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

-    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None):
+    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):

-        output = self.dec_attn(dec_inp, r,
-                               attn_mask=dec_attn_mask,
-                               mems=mems)
-        output = self.pos_ff(output)
+        attn_outputs = self.dec_attn(dec_inp, r,
+                                     attn_mask=dec_attn_mask,
+                                     mems=mems, head_mask=head_mask)
+        ff_output = self.pos_ff(attn_outputs[0])

-        return output
+        outputs = [ff_output] + attn_outputs[1:]
+        return outputs


class AdaptiveEmbedding(nn.Module):
...
@@ -791,13 +829,9 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
    """
    config_class = TransfoXLConfig
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
-
    def _init_weight(self, weight):
        if self.config.init == 'uniform':
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
...
@@ -894,6 +928,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
    """
    def __init__(self, config):
        super(TransfoXLModel, self).__init__(config)
+        self.output_attentions = config.output_attentions
+        self.output_hidden_states = config.output_hidden_states
+
        self.n_token = config.n_token
        self.d_embed = config.d_embed
...
@@ -928,7 +965,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                        r_r_bias=None if config.untie_r else self.r_r_bias)
+                        r_r_bias=None if config.untie_r else self.r_r_bias,
+                        output_attentions=self.output_attentions)
                )
        elif config.attn_type == 1: # learnable embeddings
            for i in range(config.n_layer):
...
@@ -938,7 +976,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                        r_r_bias=None if config.untie_r else self.r_r_bias)
+                        r_r_bias=None if config.untie_r else self.r_r_bias,
+                        output_attentions=self.output_attentions)
                )
        elif config.attn_type in [2, 3]: # absolute embeddings
            for i in range(config.n_layer):
...
@@ -947,7 +986,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                        r_r_bias=None if config.untie_r else self.r_r_bias)
+                        r_r_bias=None if config.untie_r else self.r_r_bias,
+                        output_attentions=self.output_attentions)
                )

        self.same_length = config.same_length
...
@@ -965,17 +1005,21 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
        elif self.attn_type == 3: # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                    self.n_layer, self.max_klen, self.n_head, self.d_head))

        self.apply(self.init_weights)

    def backward_compatible(self):
        self.sample_softmax = -1

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

+    def _prune_heads(self, heads):
+        logger.info("Head pruning is not implemented for Transformer-XL model")
+        pass
+
    def init_mems(self, data):
        if self.mem_len > 0:
            mems = []
...
@@ -1012,9 +1056,24 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
        return new_mems

-    def _forward(self, dec_inp, mems=None):
+    def _forward(self, dec_inp, mems=None, head_mask=None):
        qlen, bsz = dec_inp.size()

+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
+        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
+                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
+        else:
+            head_mask = [None] * self.n_layer
+
        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
...
@@ -1033,6 +1092,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                    word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]

        hids = []
+        attentions = []
        if self.attn_type == 0: # default
            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
...
@@ -1046,7 +1106,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
-                core_out = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i)
+                layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i, head_mask=head_mask[i])
+                core_out = layer_outputs[0]
+                if self.output_attentions:
+                    attentions.append(layer_outputs[1])
        elif self.attn_type == 1: # learnable
            core_out = self.drop(word_emb)
            for i, layer in enumerate(self.layers):
...
@@ -1058,8 +1122,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]
                mems_i = None if mems is None else mems[i]
-                core_out = layer(core_out, r_emb, self.r_w_bias[i],
-                                 r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
+                layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
+                                      r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i, head_mask=head_mask[i])
+                core_out = layer_outputs[0]
+                if self.output_attentions:
+                    attentions.append(layer_outputs[1])
        elif self.attn_type == 2: # absolute
            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
...
@@ -1074,8 +1142,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
-                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
-                                 mems=mems_i)
+                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
+                                      mems=mems_i, head_mask=head_mask[i])
+                core_out = layer_outputs[0]
+                if self.output_attentions:
+                    attentions.append(layer_outputs[1])
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)
...
@@ -1093,16 +1164,30 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
-                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
-                                 mems=mems_i)
+                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
+                                      mems=mems_i, head_mask=head_mask[i])
+                core_out = layer_outputs[0]
+                if self.output_attentions:
+                    attentions.append(layer_outputs[1])

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

-        return core_out, new_mems
+        # We transpose back here to shape [bsz, len, hidden_dim]
+        outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
+        if self.output_hidden_states:
+            # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
+            hids.append(core_out)
+            hids = list(t.transpose(0, 1).contiguous() for t in hids)
+            outputs.append(hids)
+        if self.output_attentions:
+            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
+            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
+            outputs.append(attentions)
+        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)

-    def forward(self, input_ids, mems=None):
+    def forward(self, input_ids, mems=None, head_mask=None):
        """ Params:
            input_ids :: [bsz, len]
            mems :: optional mems from previous forwar passes (or init_mems)
...
@@ -1122,11 +1207,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
        if mems is None:
            mems = self.init_mems(input_ids)

-        last_hidden, new_mems = self._forward(input_ids, mems=mems)
-
-        # We transpose back here to shape [bsz, len, hidden_dim]
-        last_hidden = last_hidden.transpose(0, 1).contiguous()
-        return (last_hidden, new_mems)
+        outputs = self._forward(input_ids, mems=mems, head_mask=head_mask)
+        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)


class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
...
@@ -1218,7 +1301,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    def init_mems(self, data):
        return self.transformer.init_mems(data)

-    def forward(self, input_ids, labels=None, mems=None):
+    def forward(self, input_ids, labels=None, mems=None, head_mask=None):
        """ Params:
            input_ids :: [bsz, len]
            labels :: [bsz, len]
...
@@ -1235,19 +1318,26 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
        bsz = input_ids.size(0)
        tgt_len = input_ids.size(1)

-        last_hidden, new_mems = self.transformer(input_ids, mems)
+        transformer_outputs = self.transformer(input_ids, mems, head_mask)
+        last_hidden = transformer_outputs[0]
        pred_hid = last_hidden[:, -tgt_len:]
+        outputs = transformer_outputs[1:]
        if self.sample_softmax > 0 and self.training:
            assert self.config.tie_weight
            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
+            outputs = [softmax_output] + outputs
+            if labels is not None:
+                # TODO: This is not implemented
+                raise NotImplementedError
        else:
            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
            if labels is None:
                softmax_output = softmax_output.view(bsz, tgt_len, -1)
+                outputs = [softmax_output] + outputs
            else:
                softmax_output = softmax_output.view(bsz, tgt_len)
+                outputs = [softmax_output, None] + outputs

-        # We transpose back
-        return (softmax_output, new_mems)
+        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
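The head_mask handling added to TransfoXLModel._forward accepts either a per-model mask of shape [n_head] or a per-layer mask of shape [n_layer, n_head]; 1.0 keeps a head and 0.0 masks it. The standalone sketch below restates that broadcasting logic outside the model; the function name and the example values are illustrative only, not part of this diff.

# Sketch only: mirrors the unsqueeze/expand sequence used in _forward above.
import torch

def expand_head_mask(head_mask, n_layer, dtype=torch.float32):
    # 1-D [n_head] -> the same mask for every layer; 2-D [n_layer, n_head] -> one row per layer.
    # Either way the result is shaped [n_layer, 1, 1, 1, n_head], which broadcasts against
    # the attention probabilities of shape [qlen, klen, bsz, n_head] inside each layer.
    if head_mask is None:
        return [None] * n_layer
    if head_mask.dim() == 1:
        head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
        head_mask = head_mask.expand(n_layer, -1, -1, -1, -1)
    elif head_mask.dim() == 2:
        head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
    return head_mask.to(dtype=dtype)

# Example: keep all 8 heads in 4 layers except head 0 of layer 0.
mask = torch.ones(4, 8)
mask[0, 0] = 0.0
expanded = expand_head_mask(mask, n_layer=4)
print(expanded.shape)  # torch.Size([4, 1, 1, 1, 8])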
pytorch_pretrained_bert/modeling_xlm.py View file @ 1484d67d
...
@@ -73,6 +73,7 @@ class XLMConfig(PretrainedConfig):
...
@@ -73,6 +73,7 @@ class XLMConfig(PretrainedConfig):
def
__init__
(
self
,
def
__init__
(
self
,
vocab_size_or_config_json_file
,
vocab_size_or_config_json_file
,
causal
=
True
,
d_model
=
1024
,
d_model
=
1024
,
n_layer
=
24
,
n_layer
=
24
,
n_head
=
16
,
n_head
=
16
,
...
@@ -145,6 +146,7 @@ class XLMConfig(PretrainedConfig):
...
@@ -145,6 +146,7 @@ class XLMConfig(PretrainedConfig):
self
.
__dict__
[
key
]
=
value
self
.
__dict__
[
key
]
=
value
elif
isinstance
(
vocab_size_or_config_json_file
,
int
):
elif
isinstance
(
vocab_size_or_config_json_file
,
int
):
self
.
n_token
=
vocab_size_or_config_json_file
self
.
n_token
=
vocab_size_or_config_json_file
self
.
causal
=
causal
self
.
d_model
=
d_model
self
.
d_model
=
d_model
self
.
n_layer
=
n_layer
self
.
n_layer
=
n_layer
self
.
n_head
=
n_head
self
.
n_head
=
n_head
...
@@ -396,7 +398,6 @@ class XLMPreTrainedModel(PreTrainedModel):
...
@@ -396,7 +398,6 @@ class XLMPreTrainedModel(PreTrainedModel):
"""
"""
config_class
=
XLMConfig
config_class
=
XLMConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map
=
PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights
=
None
load_tf_weights
=
None
base_model_prefix
=
"xlm"
base_model_prefix
=
"xlm"
...
@@ -429,7 +430,7 @@ class XLMModel(XLMPreTrainedModel):
...
@@ -429,7 +430,7 @@ class XLMModel(XLMPreTrainedModel):
'hidden_dim'
,
'dropout'
,
'attention_dropout'
,
'asm'
,
'hidden_dim'
,
'dropout'
,
'attention_dropout'
,
'asm'
,
'asm_cutoffs'
,
'asm_div_value'
]
'asm_cutoffs'
,
'asm_div_value'
]
def
__init__
(
self
,
params
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
#, dico, is_encoder, with_output):
def
__init__
(
self
,
params
,
output_attentions
=
False
,
output_hidden_states
=
False
):
#, dico, is_encoder, with_output):
""" XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
""" XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
Paper: https://arxiv.org/abs/1901.07291
Paper: https://arxiv.org/abs/1901.07291
Original code: https://github.com/facebookresearch/XLM
Original code: https://github.com/facebookresearch/XLM
...
@@ -483,11 +484,13 @@ class XLMModel(XLMPreTrainedModel):
...
@@ -483,11 +484,13 @@ class XLMModel(XLMPreTrainedModel):
"""
"""
super
(
XLMModel
,
self
).
__init__
(
params
)
super
(
XLMModel
,
self
).
__init__
(
params
)
self
.
output_attentions
=
output_attentions
self
.
output_attentions
=
output_attentions
self
.
output_hidden_states
=
output_hidden_states
# encoder / decoder, output layer
# encoder / decoder, output layer
# self.is_encoder = is_encoder
# self.is_encoder = is_encoder
# self.is_decoder = not is_encoder
# self.is_decoder = not is_encoder
# self.with_output = with_output
# self.with_output = with_output
self
.
causal
=
params
.
causal
# dictionary / languages
# dictionary / languages
self
.
n_langs
=
params
.
n_langs
self
.
n_langs
=
params
.
n_langs
...
@@ -536,63 +539,45 @@ class XLMModel(XLMPreTrainedModel):
...
@@ -536,63 +539,45 @@ class XLMModel(XLMPreTrainedModel):
self
.
ffns
.
append
(
TransformerFFN
(
self
.
dim
,
self
.
hidden_dim
,
self
.
dim
,
dropout
=
self
.
dropout
,
gelu_activation
=
params
.
gelu_activation
))
self
.
ffns
.
append
(
TransformerFFN
(
self
.
dim
,
self
.
hidden_dim
,
self
.
dim
,
dropout
=
self
.
dropout
,
gelu_activation
=
params
.
gelu_activation
))
self
.
layer_norm2
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
))
self
.
layer_norm2
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
))
# output layer
def
forward
(
self
,
x
,
lengths
,
positions
=
None
,
langs
=
None
,
cache
=
None
,
head_mask
=
None
):
# src_enc=None, src_len=None,
# if self.with_output:
# self.pred_layer = PredLayer(params)
# if params.share_inout_emb:
# self.pred_layer.proj.weight = self.embeddings.weight
# def forward(self, mode, **kwargs):
# """
# Forward function with different forward modes.
# ### Small hack to handle PyTorch distributed.
# """
# if mode == 'fwd':
# return self.fwd(**kwargs)
# elif mode == 'predict':
# return self.predict(**kwargs)
# else:
# raise Exception("Unknown mode: %s" % mode)
def
forward
(
self
,
x
,
lengths
,
causal
,
src_enc
=
None
,
src_len
=
None
,
positions
=
None
,
langs
=
None
,
cache
=
None
):
"""
"""
Inputs:
Inputs:
`x` LongTensor(slen
, bs
), containing word indices
`x` LongTensor(
bs,
slen), containing word indices
`lengths` LongTensor(bs), containing the length of each sentence
`lengths` LongTensor(bs), containing the length of each sentence
`causal` Boolean, if True, the attention is only done over previous hidden states
`causal` Boolean, if True, the attention is only done over previous hidden states
`positions` LongTensor(slen
, bs
), containing word positions
`positions` LongTensor(
bs,
slen), containing word positions
`langs` LongTensor(slen
, bs
), containing language IDs
`langs` LongTensor(
bs,
slen), containing language IDs
"""
"""
# lengths = (x != self.pad_index).float().sum(dim=1)
# lengths = (x != self.pad_index).float().sum(dim=1)
# mask = x != self.pad_index
# mask = x != self.pad_index
# check inputs
# check inputs
slen
,
bs
=
x
.
size
()
bs
,
slen
=
x
.
size
()
assert
lengths
.
size
(
0
)
==
bs
assert
lengths
.
size
(
0
)
==
bs
assert
lengths
.
max
().
item
()
<=
slen
assert
lengths
.
max
().
item
()
<=
slen
x
=
x
.
transpose
(
0
,
1
)
# batch size as dimension 0
#
x = x.transpose(0, 1) # batch size as dimension 0
assert
(
src_enc
is
None
)
==
(
src_len
is
None
)
#
assert (src_enc is None) == (src_len is None)
if
src_enc
is
not
None
:
#
if src_enc is not None:
assert
self
.
is_decoder
#
assert self.is_decoder
assert
src_enc
.
size
(
0
)
==
bs
#
assert src_enc.size(0) == bs
# generate masks
# generate masks
mask
,
attn_mask
=
get_masks
(
slen
,
lengths
,
causal
)
mask
,
attn_mask
=
get_masks
(
slen
,
lengths
,
self
.
causal
)
if
self
.
is_decoder
and
src_enc
is
not
None
:
#
if self.is_decoder and src_enc is not None:
src_mask
=
torch
.
arange
(
src_len
.
max
(),
dtype
=
torch
.
long
,
device
=
lengths
.
device
)
<
src_len
[:,
None
]
#
src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
# positions
# positions
if
positions
is
None
:
if
positions
is
None
:
positions
=
x
.
new
(
slen
).
long
()
positions
=
x
.
new
(
slen
).
long
()
positions
=
torch
.
arange
(
slen
,
out
=
positions
).
unsqueeze
(
0
)
positions
=
torch
.
arange
(
slen
,
out
=
positions
).
unsqueeze
(
0
)
else
:
else
:
assert
positions
.
size
()
==
(
slen
,
bs
)
assert
positions
.
size
()
==
(
bs
,
slen
)
#
(slen, bs)
positions
=
positions
.
transpose
(
0
,
1
)
#
positions = positions.transpose(0, 1)
# langs
# langs
if
langs
is
not
None
:
if
langs
is
not
None
:
assert
langs
.
size
()
==
(
slen
,
bs
)
assert
langs
.
size
()
==
(
bs
,
slen
)
#
(slen, bs)
langs
=
langs
.
transpose
(
0
,
1
)
#
langs = langs.transpose(0, 1)
# do not recompute cached elements
# do not recompute cached elements
if
cache
is
not
None
:
if
cache
is
not
None
:
...
@@ -614,620 +599,50 @@ class XLMModel(XLMPreTrainedModel):
...
@@ -614,620 +599,50 @@ class XLMModel(XLMPreTrainedModel):
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
+       hidden_states = []
+       attentions = []
        for i in range(self.n_layers):
+           if self.output_hidden_states:
+               hidden_states.append(tensor)

            # self attention
-           attn = self.attentions[i](tensor, attn_mask, cache=cache)
+           attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
+           attn = attn_outputs[0]
+           if self.output_attentions:
+               attentions.append(attn_outputs[1])
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

            # encoder attention (for decoder only)
-           if self.is_decoder and src_enc is not None:
+           # if self.is_decoder and src_enc is not None:
-               attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
+           #     attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
-               attn = F.dropout(attn, p=self.dropout, training=self.training)
+           #     attn = F.dropout(attn, p=self.dropout, training=self.training)
-               tensor = tensor + attn
+           #     tensor = tensor + attn
-               tensor = self.layer_norm15[i](tensor)
+           #     tensor = self.layer_norm15[i](tensor)

            # FFN
            tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norm2[i](tensor)
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

+       # Add last hidden state
+       if self.output_hidden_states:
+           hidden_states.append(tensor)

        # update cache length
        if cache is not None:
            cache['slen'] += tensor.size(1)

        # move back sequence length to dimension 0
-       tensor = tensor.transpose(0, 1)
+       # tensor = tensor.transpose(0, 1)

        return tensor
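For readers following the hunk above, here is a minimal, hypothetical sketch of the kind of masks `get_masks` is expected to return for a batch of padded sequences (the real helper is defined earlier in modeling_xlm.py; the function name, dtypes and exact shapes below are illustrative assumptions, not the library's code):

```python
import torch

def get_masks_sketch(slen, lengths, causal):
    # Hypothetical re-implementation, for illustration only.
    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
    mask = alen < lengths[:, None]                               # (bs, slen): True on real tokens
    if causal:
        attn_mask = alen[None, None, :] <= alen[None, :, None]   # (1, slen, slen): lower-triangular
        attn_mask = mask[:, None, :] & attn_mask                 # combine causal and padding masks
    else:
        attn_mask = mask[:, None, :].expand(len(lengths), slen, slen)
    return mask, attn_mask

mask, attn_mask = get_masks_sketch(slen=4, lengths=torch.tensor([4, 2]), causal=True)
```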
    def predict(self, tensor, pred_mask, y, get_scores):
        """
        Given the last hidden state, compute word scores and/or the loss.
            `pred_mask` is a ByteTensor of shape (slen, bs), filled with 1 when
                we need to predict a word
            `y` is a LongTensor of shape (pred_mask.sum(),)
            `get_scores` is a boolean specifying whether we need to return scores
        """
        masked_tensor = tensor[pred_mask.unsqueeze(-1).expand_as(tensor)].view(-1, self.dim)
        scores, loss = self.pred_layer(masked_tensor, y, get_scores)
        return scores, loss

    def generate(self, src_enc, src_len, tgt_lang_id, max_len=200, sample_temperature=None):
        """
        Decode a sentence given initial start.
        `x`:
            - LongTensor(bs, slen)
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
        `lengths`:
            - LongTensor(bs) [5, 6]
        `positions`:
            - False, for regular "arange" positions (LM)
            - True, to reset positions from the new generation (MT)
        `langs`:
            - must be None if the model only supports one language
            - lang_id if only one language is involved (LM)
            - (lang_id1, lang_id2) if two languages are involved (MT)
        """

        # input batch
        bs = len(src_len)
        assert src_enc.size(0) == bs

        # generated sentences
        generated = src_len.new(max_len, bs)  # upcoming output
        generated.fill_(self.pad_index)       # fill upcoming ouput with <PAD>
        generated[0].fill_(self.eos_index)    # we use <EOS> for <BOS> everywhere

        # positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand(max_len, bs)

        # language IDs
        langs = src_len.new(max_len).long().fill_(tgt_lang_id)
        langs = langs.unsqueeze(1).expand(max_len, bs)

        # current position / max lengths / length of generated sentences / unfinished sentences
        cur_len = 1
        gen_len = src_len.clone().fill_(1)
        unfinished_sents = src_len.clone().fill_(1)

        # cache compute states
        cache = {'slen': 0}

        while cur_len < max_len:

            # compute word scores
            tensor = self.forward(
                'fwd',
                x=generated[:cur_len],
                lengths=gen_len,
                positions=positions[:cur_len],
                langs=langs[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                cache=cache
            )
            assert tensor.size() == (1, bs, self.dim)
            tensor = tensor.data[-1, :, :]               # (bs, dim)
            scores = self.pred_layer.get_scores(tensor)  # (bs, n_words)

            # select next words: sample or greedy
            if sample_temperature is None:
                next_words = torch.topk(scores, 1)[1].squeeze(1)
            else:
                next_words = torch.multinomial(F.softmax(scores / sample_temperature, dim=1), 1).squeeze(1)
            assert next_words.size() == (bs,)

            # update generations / lengths / finished sentences / current length
            generated[cur_len] = next_words * unfinished_sents + self.pad_index * (1 - unfinished_sents)
            gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(next_words.ne(self.eos_index).long())
            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximul length
            if unfinished_sents.max() == 0:
                break

        # add <EOS> to unfinished sentences
        if cur_len == max_len:
            generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)

        # sanity check
        assert (generated == self.eos_index).sum() == 2 * bs

        return generated[:cur_len], gen_len
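The selection step inside the loop above is the usual greedy-vs-temperature-sampling choice; a standalone sketch with made-up scores (shapes assumed to be (bs, n_words)):

```python
import torch
import torch.nn.functional as F

scores = torch.randn(4, 100)                        # stand-in for (bs, n_words) word scores
greedy_words = torch.topk(scores, 1)[1].squeeze(1)  # greedy: argmax token per sentence
sampled_words = torch.multinomial(                  # sampling with temperature 0.7
    F.softmax(scores / 0.7, dim=1), 1).squeeze(1)
```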
    def generate_beam(self, src_enc, src_len, tgt_lang_id, beam_size, length_penalty, early_stopping, max_len=200):
        """
        Decode a sentence given initial start.
        `x`:
            - LongTensor(bs, slen)
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
        `lengths`:
            - LongTensor(bs) [5, 6]
        `positions`:
            - False, for regular "arange" positions (LM)
            - True, to reset positions from the new generation (MT)
        `langs`:
            - must be None if the model only supports one language
            - lang_id if only one language is involved (LM)
            - (lang_id1, lang_id2) if two languages are involved (MT)
        """

        # check inputs
        assert src_enc.size(0) == src_len.size(0)
        assert beam_size >= 1

        # batch size / number of words
        bs = len(src_len)
        n_words = self.n_words

        # expand to beam size the source latent representations / source lengths
        src_enc = src_enc.unsqueeze(1).expand((bs, beam_size) + src_enc.shape[1:]).contiguous().view((bs * beam_size,) + src_enc.shape[1:])
        src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)

        # generated sentences (batch with beam current hypotheses)
        generated = src_len.new(max_len, bs * beam_size)  # upcoming output
        generated.fill_(self.pad_index)                   # fill upcoming ouput with <PAD>
        generated[0].fill_(self.eos_index)                # we use <EOS> for <BOS> everywhere

        # generated hypotheses
        generated_hyps = [BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) for _ in range(bs)]

        # positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated)

        # language IDs
        langs = positions.clone().fill_(tgt_lang_id)

        # scores for each sentence in the beam
        beam_scores = src_enc.new(bs, beam_size).fill_(0)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)

        # current position
        cur_len = 1

        # cache compute states
        cache = {'slen': 0}

        # done sentences
        done = [False for _ in range(bs)]

        while cur_len < max_len:

            # compute word scores
            tensor = self.forward(
                'fwd',
                x=generated[:cur_len],
                lengths=src_len.new(bs * beam_size).fill_(cur_len),
                positions=positions[:cur_len],
                langs=langs[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                cache=cache
            )
            assert tensor.size() == (1, bs * beam_size, self.dim)
            tensor = tensor.data[-1, :, :]               # (bs * beam_size, dim)
            scores = self.pred_layer.get_scores(tensor)  # (bs * beam_size, n_words)
            scores = F.log_softmax(scores, dim=-1)       # (bs * beam_size, n_words)
            assert scores.size() == (bs * beam_size, n_words)

            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (bs * beam_size, n_words)
            _scores = _scores.view(bs, beam_size * n_words)            # (bs, beam_size * n_words)

            next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
            assert next_scores.size() == next_words.size() == (bs, 2 * beam_size)

            # next batch beam content
            # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for sent_id in range(bs):

                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend([(0, self.pad_index, 0)] * beam_size)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next words for this sentence
                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):

                    # get beam and word IDs
                    beam_id = idx // n_words
                    word_id = idx % n_words

                    # end of sentence, or next word
                    if word_id == self.eos_index or cur_len + 1 == max_len:
                        generated_hyps[sent_id].add(generated[:cur_len, sent_id * beam_size + beam_id].clone(), value.item())
                    else:
                        next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == beam_size:
                        break

                # update next beam content
                assert len(next_sent_beam) == 0 if cur_len + 1 == max_len else beam_size
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, self.pad_index, 0)] * beam_size  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == beam_size * (sent_id + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == bs * beam_size
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = src_len.new([x[2] for x in next_batch_beam])

            # re-order batch and internal states
            generated = generated[:, beam_idx]
            generated[cur_len] = beam_words
            for k in cache.keys():
                if k != 'slen':
                    cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break

        # visualize hypotheses
        # print([len(x) for x in generated_hyps], cur_len)
        # globals().update( locals() );
        # !import code; code.interact(local=vars())
        # for ii in range(bs):
        #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
        #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
        #     print("")

        # select the best hypotheses
        tgt_len = src_len.new(bs)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)

        # generate target batch
        decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index)
        for i, hypo in enumerate(best):
            decoded[:tgt_len[i] - 1, i] = hypo
            decoded[tgt_len[i] - 1, i] = self.eos_index

        # sanity check
        assert (decoded == self.eos_index).sum() == 2 * bs

        return decoded, tgt_len
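The inner loop above folds the beam and vocabulary dimensions together before calling `torch.topk`; the integer arithmetic that maps a flat index back to a (beam, word) pair looks like this (values are made up):

```python
n_words = 100
flat_idx = 3 * n_words + 42    # hypothetical index returned by topk over beam_size * n_words entries
beam_id = flat_idx // n_words  # -> 3: which hypothesis in the beam
word_id = flat_idx % n_words   # -> 42: which vocabulary item
assert (beam_id, word_id) == (3, 42)
```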
class XLMModel(XLMPreTrainedModel):

    def __init__(self, config, output_attentions=False, output_hidden_states=False):
        super(XLMModel, self).__init__(config)
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
        self.d_model = config.d_model
        self.same_length = config.same_length
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len

        self.word_embedding = nn.Embedding(config.n_token, config.d_model)
        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
        layer = XLMLayer(config, output_attentions=output_attentions,
                         keep_multihead_output=keep_multihead_output)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
        self.dropout = nn.Dropout(config.dropout)

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.layer[layer].attention.prune_heads(heads)

    def get_multihead_outputs(self):
        """ Gather all multi-head outputs.
            Return: list (layers) of multihead module outputs with gradients
        """
        return [layer.attention.self.multihead_output for layer in self.layer]

    def create_mask(self, qlen, mlen):
        """ create causal attention mask.
            float mask where 1.0 indicate masked, 0.0 indicated not-masked.
              same_length=False:      same_length=True:
              <mlen > <  qlen >       <mlen > < qlen >
           ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
             [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
        qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
             [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
           v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]
        """
        attn_mask = torch.ones([qlen, qlen])
        mask_up = torch.triu(attn_mask, diagonal=1)
        attn_mask_pad = torch.zeros([qlen, mlen])
        ret = torch.cat([attn_mask_pad, mask_up], dim=1)
        if self.same_length:
            mask_lo = torch.tril(attn_mask, diagonal=-1)
            ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)

        ret = ret.to(next(self.parameters()))
        return ret
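A standalone worked example of the `same_length=False` branch of `create_mask` above, with qlen=3 and mlen=2 (sizes are illustrative only):

```python
import torch

qlen, mlen = 3, 2
attn_mask = torch.ones([qlen, qlen])
mask_up = torch.triu(attn_mask, diagonal=1)       # strictly-upper triangle is masked (future tokens)
attn_mask_pad = torch.zeros([qlen, mlen])         # memory positions stay visible
ret = torch.cat([attn_mask_pad, mask_up], dim=1)
# ret == [[0., 0., 0., 1., 1.],
#         [0., 0., 0., 0., 1.],
#         [0., 0., 0., 0., 0.]]
```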
    def cache_mem(self, curr_out, prev_mem):
        """cache hidden states into memory."""
        if self.mem_len is None or self.mem_len == 0:
            return None
        else:
            if self.reuse_len is not None and self.reuse_len > 0:
                curr_out = curr_out[:self.reuse_len]

            if prev_mem is None:
                new_mem = curr_out[-self.mem_len:]
            else:
                new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]

        return new_mem.detach()

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb[:, None, :]
        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)
        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None):
        """create relative positional encoding."""
        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
        inv_freq = 1 / (10000 ** (freq_seq / self.d_model))

        if self.attn_type == 'bi':
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == 'uni':
            # beg, end = klen - 1, -1
            beg, end = klen, -1
        else:
            raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))

        if self.bi_data:
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)

            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
                bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

            if bsz is not None:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)

            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
        else:
            fwd_pos_seq = torch.arange(beg, end, -1.0)
            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        pos_emb = pos_emb.to(next(self.parameters()))
        return pos_emb
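Shape walk-through of the sinusoidal `positional_embedding` helper above (sizes are illustrative; the clamping and bi-data branches are omitted):

```python
import torch

d_model, klen, bsz = 8, 5, 2
freq_seq = torch.arange(0, d_model, 2.0, dtype=torch.float)
inv_freq = 1 / (10000 ** (freq_seq / d_model))                              # (d_model / 2,)
pos_seq = torch.arange(klen - 1, -1, -1.0)                                  # (klen,) decreasing positions
sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)                   # (klen, d_model / 2)
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)[:, None, :]
pos_emb = pos_emb.expand(-1, bsz, -1)                                       # (klen, bsz, d_model)
```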
    def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
        """
        Args:
            inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
            token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
            input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
                0 for real tokens and 1 for padding.
            attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
                but with 1 for real tokens and 0 for padding.
                Added for easy compatibility with the XLM model (which uses this negative masking).
                You can only uses one among `input_mask` and `attention_mask`
            mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
                from previous batches. The length of the list equals n_layer.
                If None, no memory is used.
            perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
                If perm_mask[k, i, j] = 0, i attend to j in batch k;
                if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
                If None, each position attends to all the others.
            target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
                If target_mapping[k, i, j] = 1, the i-th predict in batch k is
                on the j-th token.
                Only used during pretraining for partial prediction.
                Set to None during finetuning.
            inp_q: [optional] float32 Tensor in shape [bsz, len].
                1 for tokens with losses and 0 for tokens without losses.
                Only used during pretraining for two-stream attention.
                Set to None during finetuning.
            mem_len: int, the number of tokens to cache.
            reuse_len: int, the number of tokens in the currect batch to be cached
                and reused in the future.
            bi_data: bool, whether to use bidirectional input pipeline.
                Usually set to True during pretraining and False during finetuning.
            clamp_len: int, clamp all relative distances larger than clamp_len.
                -1 means no clamping.
            same_length: bool, whether to use the same attention length for each token.
            summary_type: str, "last", "first", "mean", or "attn". The method
                to pool the input to get a vector representation.
        """
        # the original code for XLM uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so we move here the first dimension (batch) to the end
        inp_k = inp_k.transpose(0, 1).contiguous()
        token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
        input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
        attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
        perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
        target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
        inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None

        qlen, bsz = inp_k.shape[0], inp_k.shape[1]
        mlen = mems[0].shape[0] if mems is not None else 0
        klen = mlen + qlen

        dtype_float = next(self.parameters()).dtype
        device = next(self.parameters()).device

        ##### Attention mask
        # causal attention mask
        if self.attn_type == 'uni':
            attn_mask = self.create_mask(qlen, mlen)
            attn_mask = attn_mask[:, :, None, None]
        elif self.attn_type == 'bi':
            attn_mask = None
        else:
            raise ValueError('Unsupported attention type: {}'.format(self.attn_type))

        # data mask: input mask & perm mask
        assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " \
            "or attention_mask (uses 0 for padding, added for compatbility with XLM). Please choose one."
        if input_mask is None and attention_mask is not None:
            input_mask = 1.0 - attention_mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
            data_mask = torch.cat([mems_mask, data_mask], dim=1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = (attn_mask > 0).to(dtype_float)

        if attn_mask is not None:
            non_tgt_mask = -torch.eye(qlen).to(attn_mask)
            non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
            non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
        else:
            non_tgt_mask = None

        ##### Word embeddings and prepare h & g hidden states
        word_emb_k = self.word_embedding(inp_k)
        output_h = self.dropout(word_emb_k)
        if inp_q is not None:
            if target_mapping is not None:
                word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
            else:
                inp_q_ext = inp_q[:, :, None]
                word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
            output_g = self.dropout(word_emb_q)
        else:
            output_g = None

        ##### Segment embedding
        if token_type_ids is not None:
            # Convert `token_type_ids` to one-hot `seg_mat`
            mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
            cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
            seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
        else:
            seg_mat = None
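A small worked example of the relative segment matrix built above, with no memory (mlen=0), qlen=3 and bsz=1 (shapes follow the [qlen, bsz] layout used at this point of the function; values are made up):

```python
import torch
import torch.nn.functional as F

token_type_ids = torch.tensor([[0], [0], [1]])                   # (qlen, bsz)
cat_ids = token_type_ids                                         # nothing to pad with when mlen == 0
seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()   # (qlen, qlen, bsz): 1 = different segment
seg_mat = F.one_hot(seg_mat, num_classes=2).float()              # (qlen, qlen, bsz, 2)
```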
        ##### Positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
        pos_emb = self.dropout(pos_emb)

        ##### Head mask if needed (for bertology/pruning)
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [n_layer x num_heads]
        # and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer

        new_mems = []
        if mems is None:
            mems = [None] * len(self.layer)

        hidden_states = []
        attentions = []
        for i, layer_module in enumerate(self.layer):
            # cache new mems
            new_mems.append(self.cache_mem(output_h, mems[i]))

            # Save hidden_states
            if output_g is None:
                hidden_states.append(output_h)
            else:
                hidden_states.append((output_h, output_g))

            output_h, output_g = layer_module(output_h, output_g,
                                              attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
                                              r=pos_emb, seg_mat=seg_mat,
                                              mems=mems[i], target_mapping=target_mapping,
                                              head_mask=head_mask)

        # Save last hidden_state
        if output_g is None:
            hidden_states.append(output_h)
        else:
            hidden_states.append((output_h, output_g))

        # Select the right output and add dropout
        output = self.dropout(output_g if output_g is not None else output_h)

        # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
        output = output.permute(1, 0, 2).contiguous()
        if output_g is None:
            hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
        else:
            hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]

        # Build the list of outputs
-       outputs = [output, new_mems]
+       outputs = [tensor]
        if self.output_hidden_states:
            outputs.append(hidden_states)
        if self.output_attentions:
            outputs.append(attentions)
        return outputs  # outputs, (hidden_states), (attentions)
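A hedged sketch of how a caller might unpack the list built above; `outputs` here is a stand-in for the returned value, and the optional entries are only present when the corresponding flags are set:

```python
def split_outputs(outputs, output_hidden_states=False, output_attentions=False):
    # Convention assumed from the comment above: [tensor, (hidden_states), (attentions)].
    tensor = outputs[0]
    rest = list(outputs[1:])
    hidden_states = rest.pop(0) if output_hidden_states else None
    attentions = rest.pop(0) if output_attentions else None
    return tensor, hidden_states, attentions
```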
class XLMPredLayer(nn.Module):
    ...
@@ -1275,63 +690,59 @@ class XLMPredLayer(nn.Module):
        return self.proj.log_prob(x) if self.asm else self.proj(x)
-class XLMLMHeadModel(XLMPreTrainedModel):
-    """XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
-    Params:
-        `config`: a XLMConfig class instance with the configuration to build a new model
-        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics. Default: False
-    Inputs:
-        inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
-        token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
-        attention_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
-            0 for real tokens and 1 for padding.
-        mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
-            from previous batches. The length of the list equals n_layer.
-            If None, no memory is used.
-        perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
-            If perm_mask[k, i, j] = 0, i attend to j in batch k;
-            if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
-            If None, each position attends to all the others.
-        target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
-            If target_mapping[k, i, j] = 1, the i-th predict in batch k is
-            on the j-th token.
-            Only used during pretraining for partial prediction.
-            Set to None during finetuning.
-        inp_q: [optional] float32 Tensor in shape [bsz, len].
-            1 for tokens with losses and 0 for tokens without losses.
-            Only used during pretraining for two-stream attention.
-            Set to None during finetuning.
-    Outputs: Tuple of (encoded_layers, pooled_output)
-        `encoded_layers`: controled by `output_all_encoded_layers` argument:
-            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
-                of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each
-                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, d_model],
-            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
-                to the last attention block of shape [batch_size, sequence_length, d_model],
-        `pooled_output`: a torch.FloatTensor of size [batch_size, d_model] which is the output of a
-            classifier pretrained on top of the hidden state associated to the first character of the
-            input (`CLS`) to train on the Next-Sentence task (see XLM's paper).
-    Example usage:
-    ```python
-    # Already been converted into WordPiece token ids
-    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
-    config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, d_model=768,
-        n_layer=12, num_attention_heads=12, intermediate_size=3072)
-    model = modeling.XLMModel(config=config)
-    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
-    ```
-    """
+class XLMWithLMHeadModel(XLMPreTrainedModel):
+    """ XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
+        Paper: https://arxiv.org/abs/1901.07291
+        Original code: https://github.com/facebookresearch/XLM
+    Params:
+        `config`: a XLMConfig class instance with the configuration to build a new model
+        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
+        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
+            This can be used to compute head importance metrics. Default: False
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
+            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
+            `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
+        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
+            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
+            a `sentence B` token (see XLM paper for more details).
+        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
+            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
+            input sequence length in the current batch. It's the mask that we typically use for attention when
+            a batch has varying length sentences.
+        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
+        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
+            It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
+    Outputs: Tuple of (encoded_layers, pooled_output)
+        `encoded_layers`: controled by `output_all_encoded_layers` argument:
+            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
+                of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each
+                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
+            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
+                to the last attention block of shape [batch_size, sequence_length, hidden_size],
+        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
+            classifier pretrained on top of the hidden state associated to the first character of the
+            input (`CLS`) to train on the Next-Sentence task (see XLM's paper).
+    Example usage:
+    ```python
+    # Already been converted into WordPiece token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+    config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+    model = modeling.XLMModel(config=config)
+    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
+    ```
+    """
    def __init__(self, config, output_attentions=False, output_hidden_states=False):
        super(XLMLMHeadModel, self).__init__(config)
        self.output_attentions = output_attentions
    ...
@@ -1341,9 +752,7 @@ class XLMLMHeadModel(XLMPreTrainedModel):
        self.same_length = config.same_length
        self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
-       self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
+       self.pred_layer = XLMPredLayer(config)
        # Tie weights
        self.apply(self.init_weights)
        self.tie_weights()
    ...
@@ -1351,10 +760,9 @@ class XLMLMHeadModel(XLMPreTrainedModel):
    def tie_weights(self):
        """ Make sure we are sharing the embeddings
        """
-       self.lm_loss.weight = self.transformer.word_embedding.weight
+       self.pred_layer.proj.weight = self.transformer.embeddings.weight
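The tying above makes the output projection and the input embedding share one parameter tensor; a minimal generic illustration (module and variable names here are not the library's):

```python
import torch.nn as nn

vocab, dim = 1000, 16
embeddings = nn.Embedding(vocab, dim)
proj = nn.Linear(dim, vocab, bias=True)
proj.weight = embeddings.weight    # both modules now update the same (vocab, dim) tensor
assert proj.weight.data_ptr() == embeddings.weight.data_ptr()
```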
-   def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
-               mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-               labels=None, head_mask=None):
+   def forward(self, x, lengths, positions=None, langs=None, cache=None,
+               labels=None, head_mask=None):
        """
        Args:
    ...
@@ -1382,11 +790,10 @@ class XLMLMHeadModel(XLMPreTrainedModel):
            summary_type: str, "last", "first", "mean", or "attn". The method
                to pool the input to get a vector representation.
        """
-       transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                              mems, perm_mask, target_mapping, inp_q, head_mask)
+       transformer_outputs = self.transformer(x, lengths, positions=positions, langs=langs, cache=cache, head_mask=head_mask)
        output = transformer_outputs[0]
-       logits = self.lm_loss(output)
+       logits = self.pred_layer(output, labels)
        outputs = transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
    ...

pytorch_pretrained_bert/modeling_xlnet.py
View file @ 1484d67d

    ...
@@ -198,7 +198,7 @@ class XLNetConfig(PretrainedConfig):
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
-                vocab_size_or_config_json_file,
+                vocab_size_or_config_json_file=32000,
                 d_model=1024,
                 n_layer=24,
                 n_head=16,
    ...
@@ -221,7 +221,12 @@ class XLNetConfig(PretrainedConfig):
                 bi_data=False,
                 clamp_len=-1,
                 same_length=False,
-                finetuning_task=None):
+                finetuning_task=None,
+                num_labels=2,
+                summary_type="last",
+                use_proj=True,
+                **kwargs):
        """Constructs XLNetConfig.
        Args:
    ...
@@ -265,6 +270,8 @@ class XLNetConfig(PretrainedConfig):
            same_length: bool, whether to use the same attention length for each token.
            finetuning_task: name of the glue task on which the model was fine-tuned if any
        """
+       super(XLNetConfig, self).__init__(**kwargs)
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
    ...
@@ -297,7 +304,11 @@ class XLNetConfig(PretrainedConfig):
            self.bi_data = bi_data
            self.clamp_len = clamp_len
            self.same_length = same_length
            self.finetuning_task = finetuning_task
+           self.num_labels = num_labels
+           self.summary_type = summary_type
+           self.use_proj = use_proj
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
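A hedged usage sketch of the widened constructor shown in the hunks above: task-level options now live on the config, and extra keyword arguments are forwarded to the base class. Only arguments visible in these hunks are used; everything else is assumed to keep its default, and the task name is purely illustrative.

```python
from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig

config = XLNetConfig(vocab_size_or_config_json_file=32000,
                     finetuning_task="mnli",   # illustrative task name
                     num_labels=3,
                     summary_type="last",
                     use_proj=True)
```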
    ...
@@ -323,9 +334,10 @@ except ImportError:
            return self.weight * x + self.bias

class XLNetRelativeAttention(nn.Module):
-   def __init__(self, config, output_attentions=False):
+   def __init__(self, config):
        super(XLNetRelativeAttention, self).__init__()
-       self.output_attentions = output_attentions
+       self.output_attentions = config.output_attentions

        if config.d_model % config.n_head != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
    ...
@@ -533,10 +545,9 @@ class XLNetFeedForward(nn.Module):
        return output

class XLNetLayer(nn.Module):
-   def __init__(self, config, output_attentions=False,):
+   def __init__(self, config):
        super(XLNetLayer, self).__init__()
-       self.output_attentions = output_attentions
-       self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions)
+       self.rel_attn = XLNetRelativeAttention(config)
        self.ff = XLNetFeedForward(config)
        self.dropout = nn.Dropout(config.dropout)
    ...
@@ -562,7 +573,6 @@ class XLNetPreTrainedModel(PreTrainedModel):
    """
    config_class = XLNetConfig
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
-   pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_xlnet
    base_model_prefix = "transformer"
    ...
@@ -589,10 +599,10 @@ class XLNetPreTrainedModel(PreTrainedModel):
class XLNetModel(XLNetPreTrainedModel):
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(XLNetModel, self).__init__(config)
-       self.output_attentions = output_attentions
+       self.output_attentions = config.output_attentions
-       self.output_hidden_states = output_hidden_states
+       self.output_hidden_states = config.output_hidden_states
        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
    ...
@@ -601,25 +611,17 @@ class XLNetModel(XLNetPreTrainedModel):
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len
+       self.n_layer = config.n_layer

        self.word_embedding = nn.Embedding(config.n_token, config.d_model)
        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
-       layer = XLNetLayer(config, output_attentions=output_attentions)
+       layer = XLNetLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
        self.dropout = nn.Dropout(config.dropout)

-   def prune_heads(self, heads_to_prune):
-       """ Prunes heads of the model.
-           heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
-       """
-       for layer, heads in heads_to_prune.items():
-           self.layer[layer].attention.prune_heads(heads)
-
-   def get_multihead_outputs(self):
-       """ Gather all multi-head outputs.
-           Return: list (layers) of multihead module outputs with gradients
-       """
-       return [layer.attention.self.multihead_output for layer in self.layer]
+   def _prune_heads(self, heads_to_prune):
+       logger.info("Head pruning is not implemented for XLNet")
+       pass

    def create_mask(self, qlen, mlen):
        """ create causal attention mask.
    ...
@@ -708,11 +710,11 @@ class XLNetModel(XLNetPreTrainedModel):
        pos_emb = pos_emb.to(next(self.parameters()))
        return pos_emb

-   def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
+   def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
        """
        Args:
-           inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+           input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
            token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
            input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
                0 for real tokens and 1 for padding.
    ...
@@ -751,7 +753,7 @@ class XLNetModel(XLNetPreTrainedModel):
        # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so we move here the first dimension (batch) to the end
-       inp_k = inp_k.transpose(0, 1).contiguous()
+       input_ids = input_ids.transpose(0, 1).contiguous()
        token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
        input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
        attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
    ...
@@ -759,7 +761,7 @@ class XLNetModel(XLNetPreTrainedModel):
        target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
        inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None

-       qlen, bsz = inp_k.shape[0], inp_k.shape[1]
+       qlen, bsz = input_ids.shape[0], input_ids.shape[1]
        mlen = mems[0].shape[0] if mems is not None else 0
        klen = mlen + qlen
    ...
@@ -810,7 +812,7 @@ class XLNetModel(XLNetPreTrainedModel):
            non_tgt_mask = None

        ##### Word embeddings and prepare h & g hidden states
-       word_emb_k = self.word_embedding(inp_k)
+       word_emb_k = self.word_embedding(input_ids)
        output_h = self.dropout(word_emb_k)
        if inp_q is not None:
            if target_mapping is not None:
    ...
@@ -838,20 +840,20 @@ class XLNetModel(XLNetPreTrainedModel):
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
        pos_emb = self.dropout(pos_emb)

-       ##### Head mask if needed (for bertology/pruning)
+       # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
-       # input head_mask has shape [num_heads] or [n_layer x num_heads]
+       # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
-       # and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length]
+       # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            if head_mask.dim() == 1:
-               head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+               head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
-               head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
+               head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
-               head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
+               head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
        else:
-           head_mask = [None] * self.config.n_layer
+           head_mask = [None] * self.n_layer
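The broadcasting convention described in the comments above, shown on a toy per-head mask (n_layer=2 and n_head=4 are assumed values for illustration):

```python
import torch

n_layer, n_head = 2, 4
head_mask = torch.ones(n_head)                                              # keep every head
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)   # (1, 1, 1, 1, n_head)
head_mask = head_mask.expand(n_layer, -1, -1, -1, -1)                       # one broadcastable mask per layer
```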
        new_mems = []
        if mems is None:
    ...
@@ -870,7 +872,7 @@ class XLNetModel(XLNetPreTrainedModel):
                                  head_mask=head_mask[i])
            output_h, output_g = outputs[:2]
            if self.output_attentions:
-               attentions.append(outputs[2:])
+               attentions.append(outputs[2])

        # Add last hidden state
        if self.output_hidden_states:
    ...
@@ -887,6 +889,7 @@ class XLNetModel(XLNetPreTrainedModel):
            hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
            outputs.append(hidden_states)
        if self.output_attentions:
+           attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
            outputs.append(attentions)

        return outputs  # outputs, new_mems, (hidden_states), (attentions)
    ...
@@ -902,7 +905,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            This can be used to compute head importance metrics. Default: False

    Inputs:
-       inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+       input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
        token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
        input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
            0 for real tokens and 1 for padding.
    ...
@@ -953,16 +956,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(XLNetLMHeadModel, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
        self.attn_type = config.attn_type
        self.same_length = config.same_length

-       self.transformer = XLNetModel(config, output_attentions=output_attentions,
-                                     output_hidden_states=output_hidden_states)
+       self.transformer = XLNetModel(config)
        self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
        # Tie weights
    ...
@@ -975,12 +974,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
        """
        self.lm_loss.weight = self.transformer.word_embedding.weight

-   def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
+   def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
                labels=None, head_mask=None):
        """
        Args:
-           inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+           input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
            token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
            input_mask: float32 Tensor in shape [bsz, len], the input mask.
                0 for real tokens and 1 for padding.
    ...
@@ -1008,7 +1007,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            summary_type: str, "last", "first", "mean", or "attn". The method
                to pool the input to get a vector representation.
        """
-       transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+       transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
                                               mems, perm_mask, target_mapping, inp_q, head_mask)

        logits = self.lm_loss(transformer_outputs[0])
    ...
@@ -1025,14 +1024,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
        return outputs  # return (loss), logits, (mems), (hidden states), (attentions)

class XLNetSequenceSummary(nn.Module):
-   def __init__(self, config, summary_type="last", use_proj=True):
+   def __init__(self, config):
        super(XLNetSequenceSummary, self).__init__()
-       self.summary_type = summary_type
+       self.summary_type = config.summary_type
-       if use_proj:
+       if config.use_proj:
            self.summary = nn.Linear(config.d_model, config.d_model)
        else:
            self.summary = None
-       if summary_type == 'attn':
+       if config.summary_type == 'attn':
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
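The three simple `summary_type` options named in the docstrings reduce a (bsz, seq_len, d_model) tensor to one vector per sequence; a toy illustration with made-up values ("attn" is the attention-based variant discussed in the comments above):

```python
import torch

hidden = torch.randn(2, 7, 16)       # (bsz, seq_len, d_model)
last_summary = hidden[:, -1]         # "last"
first_summary = hidden[:, 0]         # "first"
mean_summary = hidden.mean(dim=1)    # "mean"
```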
    ...
@@ -1069,7 +1068,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
            to pool the input to get a vector representation. Default: last

    Inputs:
-       inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+       input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
        token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
        input_mask: float32 Tensor in shape [bsz, len], the input mask.
            0 for real tokens and 1 for padding.
    ...
@@ -1121,30 +1120,21 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
-                output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(XLNetForSequenceClassification, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.attn_type = config.attn_type
-       self.same_length = config.same_length
-       self.summary_type = summary_type
-       self.num_labels = num_labels
-       self.transformer = XLNetModel(config, output_attentions=output_attentions,
-                                     output_hidden_states=output_hidden_states)
-       self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type, use_proj=use_proj)
-       self.logits_proj = nn.Linear(config.d_model, num_labels)
+       self.transformer = XLNetModel(config)
+       self.sequence_summary = XLNetSequenceSummary(config)
+       self.logits_proj = nn.Linear(config.d_model, config.num_labels)
        self.apply(self.init_weights)
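Sketch of the classification head wired up above: the label count now comes from `config.num_labels` rather than a per-model keyword argument (sizes below are illustrative stand-ins, not the library's defaults):

```python
import torch
import torch.nn as nn

d_model, num_labels = 16, 3
summary = torch.randn(2, d_model)            # stand-in for the XLNetSequenceSummary output
logits_proj = nn.Linear(d_model, num_labels)
logits = logits_proj(summary)                # (bsz, num_labels)
```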
-   def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
+   def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
                labels=None, head_mask=None):
        """
        Args:
-           inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+           input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
            token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
            input_mask: float32 Tensor in shape [bsz, len], the input mask.
                0 for real tokens and 1 for padding.
    ...
@@ -1169,7 +1159,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                Only used during pretraining for two-stream attention.
                Set to None during finetuning.
        """
-       transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+       transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
                                               mems, perm_mask, target_mapping, inp_q, head_mask)
        output = transformer_outputs[0]
    ...
@@ -1247,20 +1237,18 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
    start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(XLNetForQuestionAnswering, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.transformer = XLNetModel(config, output_attentions=output_attentions,
-                                     output_hidden_states=output_hidden_states)
-       self.qa_outputs = nn.Linear(config.hidden_size, 2)
+       self.transformer = XLNetModel(config)
+       self.qa_outputs = nn.Linear(config.d_model, config.num_labels)
        self.apply(self.init_weights)

-   def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
+   def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
                start_positions=None, end_positions=None, head_mask=None):
-       transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+       transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
                                               mems, perm_mask, target_mapping, inp_q, head_mask)
        logits = self.qa_outputs(transformer_outputs[0])
...
...
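For orientation, here is a minimal usage sketch of the reworked head after this change: the head-specific settings now live on the config object instead of the constructor, and the first input is named input_ids. The config values below are illustrative only, and XLNetConfig is assumed to be importable from the package root at this commit.

import torch
from pytorch_pretrained_bert import XLNetConfig, XLNetForSequenceClassification

# Illustrative, small config; num_labels now lives on the config rather than the constructor.
config = XLNetConfig(vocab_size_or_config_json_file=32000)
config.num_labels = 3

model = XLNetForSequenceClassification(config)  # was XLNetForSequenceClassification(config, num_labels=3, ...)
model.eval()

input_ids = torch.randint(0, 32000, (2, 16), dtype=torch.long)  # [bsz, len]
outputs = model(input_ids)                                      # first argument was previously named inp_k
logits = outputs[0]                                             # [bsz, num_labels]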
pytorch_pretrained_bert/tests/__init__.py 0 → 100644
View file @ 1484d67d

tests/conftest.py → pytorch_pretrained_bert/tests/conftest.py
View file @ 1484d67d
File moved

samples/input.txt → pytorch_pretrained_bert/tests/fixtures/input.txt
View file @ 1484d67d
File moved

samples/sample_text.txt → pytorch_pretrained_bert/tests/fixtures/sample_text.txt
View file @ 1484d67d
File moved

samples/test_sentencepiece.model → pytorch_pretrained_bert/tests/fixtures/test_sentencepiece.model
View file @ 1484d67d
File moved

pytorch_pretrained_bert/tests/model_tests_commons.py 0 → 100644
View file @ 1484d67d
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import json
import random

import torch


def create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
    for model_class in model_classes:
        config.output_hidden_states = True
        model = model_class(config=config)
        model.eval()
        head_mask = torch.zeros(tester.num_hidden_layers, tester.num_attention_heads)
        # Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
        head_mask.requires_grad_(requires_grad=True)
        outputs = model(**inputs_dict, head_mask=head_mask)

        # Compute some gradients
        output = sum(t.sum() for t in outputs[0])
        output = output.sum()
        output.backward()
        multihead_outputs = head_mask.grad

        tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[1].nonzero()),
# multihead_outputs[1].numel())
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
def create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
    for model_class in model_classes:
        model = model_class(config=config)
        model.eval()
        heads_to_prune = {0: list(range(1, tester.num_attention_heads)),
                          -1: [0]}
        model.prune_heads(heads_to_prune)
        outputs = model(**inputs_dict)
# output = sum(t.sum() for t in outputs[0])
# output = output.sum()
# output.backward()
# multihead_outputs = bert_model.get_multihead_outputs()
# self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, 1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads-1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
def create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
    for model_class in model_classes:
        config.output_attentions = True
        config.output_hidden_states = False
        model = model_class(config)
        model.eval()
        outputs = model(**inputs_dict)
        attentions = outputs[-1]
        tester.parent.assertEqual(model.config.output_attentions, True)
        tester.parent.assertEqual(model.config.output_hidden_states, False)
        tester.parent.assertEqual(len(attentions), tester.num_hidden_layers)
        tester.parent.assertListEqual(
            list(attentions[0].shape[-3:]),
            [tester.num_attention_heads,
             tester.seq_length,
             tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])
        out_len = len(outputs)

        # Check attention is always last and order is fine
        config.output_attentions = True
        config.output_hidden_states = True
        model = model_class(config)
        model.eval()
        outputs = model(**inputs_dict)
        tester.parent.assertEqual(out_len + 1, len(outputs))
        tester.parent.assertEqual(model.config.output_attentions, True)
        tester.parent.assertEqual(model.config.output_hidden_states, True)
        attentions = outputs[-1]
        tester.parent.assertEqual(len(attentions), tester.num_hidden_layers)
        tester.parent.assertListEqual(
            list(attentions[0].shape[-3:]),
            [tester.num_attention_heads,
             tester.seq_length,
             tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])


def create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
    for model_class in model_classes:
        config.output_hidden_states = True
        config.output_attentions = False
        model = model_class(config)
        model.eval()
        outputs = model(**inputs_dict)
        hidden_states = outputs[-1]
        tester.parent.assertEqual(model.config.output_attentions, False)
        tester.parent.assertEqual(model.config.output_hidden_states, True)
        tester.parent.assertEqual(len(hidden_states), tester.num_hidden_layers + 1)
        tester.parent.assertListEqual(
            list(hidden_states[0].shape[-2:]),
            [tester.seq_length, tester.hidden_size])


def create_and_check_commons(tester, config, inputs_dict):
    create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
    create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
    create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
    create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)


def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
class ConfigTester(object):
    def __init__(self, parent, config_class=None, **kwargs):
        self.parent = parent
        self.config_class = config_class
        self.inputs_dict = kwargs

    def create_and_test_config_to_json_string(self):
        config = self.config_class(**self.inputs_dict)
        obj = json.loads(config.to_json_string())
        for key, value in self.inputs_dict.items():
            self.parent.assertEqual(obj[key], value)

    def create_and_test_config_to_json_file(self):
        config_first = self.config_class(**self.inputs_dict)
        json_file_path = "/tmp/config.json"
        config_first.to_json_file(json_file_path)
        config_second = self.config_class.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def run_common_tests(self):
        self.create_and_test_config_to_json_string()
        self.create_and_test_config_to_json_file()
class GPTModelTester(object):
    def __init__(self,
                 parent,
                 batch_size=13,
                 seq_length=7,
                 is_training=True,
                 use_position_ids=True,
                 use_token_type_ids=True,
                 use_labels=True,
                 vocab_size=99,
                 n_special=1,
                 n_positions=33,
                 hidden_size=32,
                 num_hidden_layers=5,
                 num_attention_heads=4,
                 n_choices=3,
                 type_sequence_label_size=2,
                 initializer_range=0.02,
                 num_labels=3,
                 scope=None,
                 config_class=None,
                 base_model_class=None,
                 lm_head_model_class=None,
                 double_head_model_class=None,
                 ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_position_ids = use_position_ids
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.n_special = n_special
        self.n_positions = n_positions
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_choices = n_choices
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.scope = scope
        self.config_class = config_class
        self.base_model_class = base_model_class
        self.lm_head_model_class = lm_head_model_class
        self.double_head_model_class = double_head_model_class
        self.all_model_classes = (base_model_class, lm_head_model_class, double_head_model_class)

    def prepare_config_and_inputs(self):
        total_num_tokens = self.vocab_size + self.n_special
        input_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)

        position_ids = None
        if self.use_position_ids:
            position_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)

        token_type_ids = None
        if self.use_token_type_ids:
            total_voc = self.vocab_size
            token_type_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)

        mc_labels = None
        lm_labels = None
        mc_token_ids = None
        if self.use_labels:
            mc_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            lm_labels = ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
            mc_token_ids = ids_tensor([self.batch_size, self.n_choices], self.seq_length)

        config = self.config_class(
            vocab_size_or_config_json_file=self.vocab_size,
            n_special=self.n_special,
            n_positions=self.n_positions,
            n_embd=self.hidden_size,
            n_layer=self.num_hidden_layers,
            n_head=self.num_attention_heads,
            initializer_range=self.initializer_range)

        return (config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids)

    def create_and_check_base_model(self, config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids):
        model = self.base_model_class(config)
        model.eval()
        outputs = model(input_ids, position_ids, token_type_ids)
        hidden_state = outputs[0]
        self.parent.assertListEqual(
            list(hidden_state.size()),
            [self.batch_size, self.n_choices, self.seq_length, self.hidden_size])

    def create_and_check_lm_head(self, config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids):
        model = self.lm_head_model_class(config)
        model.eval()
        outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
        loss, lm_logits = outputs[:2]
        total_voc = self.n_special + self.vocab_size
        self.parent.assertListEqual(
            list(lm_logits.size()),
            [self.batch_size, self.n_choices, self.seq_length, total_voc])
        self.parent.assertListEqual(
            list(loss.size()),
            [])

    def create_and_check_presents(self, config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids):
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.eval()
            outputs = model(input_ids)
            presents = outputs[-1]
            self.parent.assertEqual(self.num_hidden_layers, len(presents))
            self.parent.assertListEqual(
                list(presents[0].size()),
                [2, self.batch_size * self.n_choices, self.num_attention_heads,
                 self.seq_length, self.hidden_size // self.num_attention_heads])

    def create_and_check_double_heads(self, config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids):
        model = self.double_head_model_class(config)
        model.eval()
        outputs = model(input_ids, mc_token_ids,
                        lm_labels=lm_labels, mc_labels=mc_labels,
                        token_type_ids=token_type_ids, position_ids=position_ids)
        lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4]
        loss = [lm_loss, mc_loss]
        total_voc = self.n_special + self.vocab_size
        self.parent.assertListEqual(
            list(lm_logits.size()),
            [self.batch_size, self.n_choices, self.seq_length, total_voc])
        self.parent.assertListEqual(
            list(mc_logits.size()),
            [self.batch_size, self.n_choices])
        self.parent.assertListEqual(
            [list(l.size()) for l in loss],
            [[], []])

    def create_and_check_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
        for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.parent.assertIsNotNone(model)

    def create_and_check_commons(self, config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids):
        inputs_dict = {'input_ids': input_ids}
        create_and_check_commons(self, config, inputs_dict)

    def run_common_tests(self, test_presents=False):
        config_and_inputs = self.prepare_config_and_inputs()
        self.create_and_check_base_model(*config_and_inputs)

        config_and_inputs = self.prepare_config_and_inputs()
        self.create_and_check_lm_head(*config_and_inputs)

        config_and_inputs = self.prepare_config_and_inputs()
        self.create_and_check_double_heads(*config_and_inputs)

        if test_presents:
            config_and_inputs = self.prepare_config_and_inputs()
            self.create_and_check_presents(*config_and_inputs)

        config_and_inputs = self.prepare_config_and_inputs()
        self.create_and_check_commons(*config_and_inputs)

    def run_slow_tests(self):
        config_and_inputs = self.prepare_config_and_inputs()
        self.create_and_check_model_from_pretrained(*config_and_inputs)
pytorch_pretrained_bert/tests/model_utils_test.py 0 → 100644
View file @ 1484d67d
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import PretrainedConfig, PreTrainedModel
from pytorch_pretrained_bert.modeling import (BertModel, BertConfig,
                                              PRETRAINED_MODEL_ARCHIVE_MAP,
                                              PRETRAINED_CONFIG_ARCHIVE_MAP)


class ModelUtilsTest(unittest.TestCase):
    def test_model_from_pretrained(self):
        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            config = BertConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, PretrainedConfig)

            model = BertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, PreTrainedModel)

            config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
            model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
            self.assertEqual(model.config.output_attentions, True)
            self.assertEqual(model.config.output_hidden_states, True)
            self.assertEqual(model.config, config)


if __name__ == "__main__":
    unittest.main()
pytorch_pretrained_bert/tests/modeling_gpt2_test.py 0 → 100644
View file @ 1484d67d
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
                                     GPT2LMHeadModel, GPT2DoubleHeadsModel)

from .model_tests_commons import (create_and_check_for_attentions,
                                  create_and_check_for_head_pruning,
                                  create_and_check_for_headmasking,
                                  create_and_check_for_hidden_states,
                                  ConfigTester, GPTModelTester)


class GPT2ModelTest(unittest.TestCase):
    def test_config(self):
        config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)
        config_tester.run_common_tests()

    def test_model(self):
        model_tester = GPTModelTester(self,
                                      config_class=GPT2Config,
                                      base_model_class=GPT2Model,
                                      lm_head_model_class=GPT2LMHeadModel,
                                      double_head_model_class=GPT2DoubleHeadsModel)
        model_tester.run_common_tests(test_presents=True)

    @pytest.mark.slow
    def test_pretrained(self):
        model_tester = GPTModelTester(self,
                                      config_class=GPT2Config,
                                      base_model_class=GPT2Model,
                                      lm_head_model_class=GPT2LMHeadModel,
                                      double_head_model_class=GPT2DoubleHeadsModel)
        model_tester.run_slow_tests()


if __name__ == "__main__":
    unittest.main()
pytorch_pretrained_bert/tests/modeling_openai_test.py 0 → 100644
View file @ 1484d67d
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
                                     OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)

from .model_tests_commons import (create_and_check_for_attentions,
                                  create_and_check_for_head_pruning,
                                  create_and_check_for_headmasking,
                                  create_and_check_for_hidden_states,
                                  ConfigTester, GPTModelTester)


class OpenAIModelTest(unittest.TestCase):
    def test_config(self):
        config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37)
        config_tester.run_common_tests()

    def test_model(self):
        model_tester = GPTModelTester(self,
                                      config_class=OpenAIGPTConfig,
                                      base_model_class=OpenAIGPTModel,
                                      lm_head_model_class=OpenAIGPTLMHeadModel,
                                      double_head_model_class=OpenAIGPTDoubleHeadsModel)
        model_tester.run_common_tests(test_presents=False)

    @pytest.mark.slow
    def test_pretrained(self):
        model_tester = GPTModelTester(self,
                                      config_class=OpenAIGPTConfig,
                                      base_model_class=OpenAIGPTModel,
                                      lm_head_model_class=OpenAIGPTLMHeadModel,
                                      double_head_model_class=OpenAIGPTDoubleHeadsModel)
        model_tester.run_slow_tests()


if __name__ == "__main__":
    unittest.main()
tests/modeling_test.py → pytorch_pretrained_bert/tests/modeling_test.py
View file @ 1484d67d
...
@@ -31,6 +31,8 @@ from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                      BertForTokenClassification, BertForMultipleChoice)
 from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
+from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
 
 class BertModelTest(unittest.TestCase):
     class BertModelTester(object):
...
@@ -57,7 +59,11 @@ class BertModelTest(unittest.TestCase):
                      initializer_range=0.02,
                      num_labels=3,
                      num_choices=4,
-                     scope=None):
+                     scope=None,
+                     all_model_classes=(BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                        BertForPreTraining, BertForQuestionAnswering,
+                                        BertForSequenceClassification, BertForTokenClassification),
+                     ):
             self.parent = parent
             self.batch_size = batch_size
             self.seq_length = seq_length
...
@@ -80,25 +86,26 @@ class BertModelTest(unittest.TestCase):
             self.num_labels = num_labels
             self.num_choices = num_choices
             self.scope = scope
+            self.all_model_classes = all_model_classes
 
         def prepare_config_and_inputs(self):
-            input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
             input_mask = None
             if self.use_input_mask:
-                input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
 
             token_type_ids = None
             if self.use_token_type_ids:
-                token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
 
             sequence_labels = None
             token_labels = None
             choice_labels = None
             if self.use_labels:
-                sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
-                token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-                choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)
+                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+                choice_labels = ids_tensor([self.batch_size], self.num_choices)
 
             config = BertConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
...
@@ -120,136 +127,117 @@ class BertModelTest(unittest.TestCase):
                 list(result["loss"].size()),
                 [])
-        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertModel(config=config)
             model.eval()
             sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
-            model = BertModel(config=config, output_hidden_states=True)
-            model.eval()
-            _, _, all_encoder_layers = model(input_ids, token_type_ids, input_mask)
-            outputs = {
+            result = {
                 "sequence_output": sequence_output,
                 "pooled_output": pooled_output,
-                "all_encoder_layers": all_encoder_layers,
             }
-            return outputs
-
-        def check_bert_model_output(self, result):
-            self.parent.assertListEqual(
-                [size for layer in result["all_encoder_layers"] for size in layer.size()],
-                [self.batch_size, self.seq_length, self.hidden_size] * (self.num_hidden_layers + 1))
             self.parent.assertListEqual(
                 list(result["sequence_output"].size()),
                 [self.batch_size, self.seq_length, self.hidden_size])
             self.parent.assertListEqual(
                 list(result["pooled_output"].size()),
                 [self.batch_size, self.hidden_size])
-        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForMaskedLM(config=config)
             model.eval()
             loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
-            outputs = {
+            result = {
                 "loss": loss,
                 "prediction_scores": prediction_scores,
             }
-            return outputs
-
-        def check_bert_for_masked_lm_output(self, result):
             self.parent.assertListEqual(
                 list(result["prediction_scores"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
+            self.check_loss_output(result)
-        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForNextSentencePrediction(config=config)
             model.eval()
             loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
-            outputs = {
+            result = {
                 "loss": loss,
                 "seq_relationship_score": seq_relationship_score,
             }
-            return outputs
-
-        def check_bert_for_next_sequence_prediction_output(self, result):
             self.parent.assertListEqual(
                 list(result["seq_relationship_score"].size()),
                 [self.batch_size, 2])
+            self.check_loss_output(result)
-        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForPreTraining(config=config)
             model.eval()
             loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
-            outputs = {
+            result = {
                 "loss": loss,
                 "prediction_scores": prediction_scores,
                 "seq_relationship_score": seq_relationship_score,
             }
-            return outputs
-
-        def check_bert_for_pretraining_output(self, result):
             self.parent.assertListEqual(
                 list(result["prediction_scores"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
                 list(result["seq_relationship_score"].size()),
                 [self.batch_size, 2])
+            self.check_loss_output(result)
-        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForQuestionAnswering(config=config)
             model.eval()
             loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
-            outputs = {
+            result = {
                 "loss": loss,
                 "start_logits": start_logits,
                 "end_logits": end_logits,
             }
-            return outputs
-
-        def check_bert_for_question_answering_output(self, result):
             self.parent.assertListEqual(
                 list(result["start_logits"].size()),
                 [self.batch_size, self.seq_length])
             self.parent.assertListEqual(
                 list(result["end_logits"].size()),
                 [self.batch_size, self.seq_length])
+            self.check_loss_output(result)
-        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
+            config.num_labels = self.num_labels
+            model = BertForSequenceClassification(config)
             model.eval()
             loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
-            outputs = {
+            result = {
                 "loss": loss,
                 "logits": logits,
             }
-            return outputs
-
-        def check_bert_for_sequence_classification_output(self, result):
             self.parent.assertListEqual(
                 list(result["logits"].size()),
                 [self.batch_size, self.num_labels])
+            self.check_loss_output(result)
-        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            model = BertForTokenClassification(config=config, num_labels=self.num_labels)
+            config.num_labels = self.num_labels
+            model = BertForTokenClassification(config=config)
             model.eval()
             loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
-            outputs = {
+            result = {
                 "loss": loss,
                 "logits": logits,
             }
-            return outputs
-
-        def check_bert_for_token_classification_output(self, result):
             self.parent.assertListEqual(
                 list(result["logits"].size()),
                 [self.batch_size, self.seq_length, self.num_labels])
+            self.check_loss_output(result)
-        def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
+            config.num_choices = self.num_choices
+            model = BertForMultipleChoice(config=config)
             model.eval()
             multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
             multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
...
@@ -258,148 +246,26 @@ class BertModelTest(unittest.TestCase):
                                  multiple_choice_token_type_ids,
                                  multiple_choice_input_mask,
                                  choice_labels)
-            outputs = {
+            result = {
                 "loss": loss,
                 "logits": logits,
             }
-            return outputs
-
-        def check_bert_for_multiple_choice(self, result):
             self.parent.assertListEqual(
                 list(result["logits"].size()),
                 [self.batch_size, self.num_choices])
+            self.check_loss_output(result)
-        def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
-                                BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
-                                BertForTokenClassification):
-                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
-                    model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
-                else:
-                    model = model_class(config=config, output_attentions=True)
-                model.eval()
-                outputs = model(input_ids, token_type_ids, input_mask)
-                attentions = outputs[-1]
-                self.parent.assertEqual(len(attentions), self.num_hidden_layers)
-                self.parent.assertListEqual(
-                    list(attentions[0].size()),
-                    [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
+        def create_and_check_bert_commons(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
+            create_and_check_commons(self, config, inputs_dict)
 
-        def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
-                                BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
-                                BertForTokenClassification):
-                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
-                    model = model_class(config=config, num_labels=self.num_labels)
-                else:
-                    model = model_class(config=config)
-                model.eval()
-                head_mask = torch.ones(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
-                head_mask[0, 1:-1] = 0.0  # Mask all but the first and last heads on the first layer
-                head_mask[-1, 1:] = 0.0   # Mask all but the first head on the last layer
-                # Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
-                head_mask.requires_grad_(requires_grad=True)
-                outputs = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
-
-                # Compute some gradients
-                output = sum(t.sum() for t in outputs[0])
-                output = output.sum()
-                output.backward()
-                multihead_outputs = head_mask.grad
-
-                self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[1].nonzero()),
# multihead_outputs[1].numel())
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-        def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
-            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
-                                BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
-                                BertForTokenClassification):
-                if model_class in [BertForSequenceClassification, BertForTokenClassification]:
-                    model = model_class(config=config, num_labels=self.num_labels)
-                else:
-                    model = model_class(config=config)
-                model.eval()
-                bert_model = model if isinstance(model, BertModel) else model.bert
-                heads_to_prune = {0: list(range(1, self.num_attention_heads)),
-                                  -1: [0]}
-                bert_model.prune_heads(heads_to_prune)
-                outputs = model(input_ids, token_type_ids, input_mask)
# output = sum(t.sum() for t in outputs[0])
# output = output.sum()
# output.backward()
# multihead_outputs = bert_model.get_multihead_outputs()
# self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, 1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads-1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))
 
-    def test_config_to_json_string(self):
-        config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
-        obj = json.loads(config.to_json_string())
-        self.assertEqual(obj["vocab_size"], 99)
-        self.assertEqual(obj["hidden_size"], 37)
-
-    def test_config_to_json_file(self):
-        config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
-        json_file_path = "/tmp/config.json"
-        config_first.to_json_file(json_file_path)
-        config_second = BertConfig.from_json_file(json_file_path)
-        os.remove(json_file_path)
-        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+    def test_config(self):
+        config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
+        config_tester.run_common_tests()
 
     @pytest.mark.slow
     def test_model_from_pretrained(self):
...
@@ -411,57 +277,31 @@ class BertModelTest(unittest.TestCase):
     def run_tester(self, tester):
         config_and_inputs = tester.prepare_config_and_inputs()
-        output_result = tester.create_bert_model(*config_and_inputs)
-        tester.check_bert_model_output(output_result)
+        tester.create_and_check_bert_model(*config_and_inputs)
 
-        output_result = tester.create_bert_for_masked_lm(*config_and_inputs)
-        tester.check_bert_for_masked_lm_output(output_result)
-        tester.check_loss_output(output_result)
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
 
-        output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs)
-        tester.check_bert_for_next_sequence_prediction_output(output_result)
-        tester.check_loss_output(output_result)
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
 
-        output_result = tester.create_bert_for_pretraining(*config_and_inputs)
-        tester.check_bert_for_pretraining_output(output_result)
-        tester.check_loss_output(output_result)
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
 
-        output_result = tester.create_bert_for_question_answering(*config_and_inputs)
-        tester.check_bert_for_question_answering_output(output_result)
-        tester.check_loss_output(output_result)
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_pretraining(*config_and_inputs)
 
-        output_result = tester.create_bert_for_sequence_classification(*config_and_inputs)
-        tester.check_bert_for_sequence_classification_output(output_result)
-        tester.check_loss_output(output_result)
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_question_answering(*config_and_inputs)
 
-        output_result = tester.create_bert_for_token_classification(*config_and_inputs)
-        tester.check_bert_for_token_classification_output(output_result)
-        tester.check_loss_output(output_result)
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
 
-        output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
-        tester.check_bert_for_multiple_choice(output_result)
-        tester.check_loss_output(output_result)
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_for_token_classification(*config_and_inputs)
 
-        tester.create_and_check_bert_for_attentions(*config_and_inputs)
-        tester.create_and_check_bert_for_headmasking(*config_and_inputs)
-        tester.create_and_check_bert_for_head_pruning(*config_and_inputs)
-
-    @classmethod
-    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
-        """Creates a random int32 tensor of the shape within the vocab size."""
-        if rng is None:
-            rng = random.Random()
-
-        total_dims = 1
-        for dim in shape:
-            total_dims *= dim
-
-        values = []
-        for _ in range(total_dims):
-            values.append(rng.randint(0, vocab_size - 1))
-
-        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_bert_commons(*config_and_inputs)
 
 if __name__ == "__main__":
     unittest.main()
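The BERT heads now follow the same pattern the updated tests assume: task-specific sizes are read from attributes on the config instead of extra constructor arguments. A minimal sketch with illustrative, small hyper-parameters (matching the sizes used by the testers above):

from pytorch_pretrained_bert import BertConfig, BertForSequenceClassification

# Small, illustrative configuration.
config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=5, num_attention_heads=4, intermediate_size=37)

# Old style (removed in this commit): BertForSequenceClassification(config, num_labels=3)
config.num_labels = 3
model = BertForSequenceClassification(config)
model.eval()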
tests/modeling_transfo_xl_test.py → pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py
View file @ 1484d67d
...
@@ -28,6 +28,8 @@ import torch
 from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
 from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
+from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
 
 class TransfoXLModelTest(unittest.TestCase):
     class TransfoXLModelTester(object):
...
@@ -41,54 +43,58 @@ class TransfoXLModelTest(unittest.TestCase):
                      use_labels=True,
                      vocab_size=99,
                      cutoffs=[10, 50, 80],
-                     d_model=32,
+                     hidden_size=32,
                      d_embed=32,
-                     n_head=4,
+                     num_attention_heads=4,
                      d_head=8,
                      d_inner=128,
                      div_val=2,
-                     n_layer=5,
+                     num_hidden_layers=5,
                      scope=None,
-                     seed=1):
+                     seed=1,
+                     all_model_classes=(TransfoXLModel, TransfoXLLMHeadModel),
+                     ):
             self.parent = parent
             self.batch_size = batch_size
             self.seq_length = seq_length
             self.mem_len = mem_len
+            self.key_len = seq_length + mem_len
             self.clamp_len = clamp_len
             self.is_training = is_training
             self.use_labels = use_labels
             self.vocab_size = vocab_size
             self.cutoffs = cutoffs
-            self.d_model = d_model
+            self.hidden_size = hidden_size
             self.d_embed = d_embed
-            self.n_head = n_head
+            self.num_attention_heads = num_attention_heads
             self.d_head = d_head
             self.d_inner = d_inner
             self.div_val = div_val
-            self.n_layer = n_layer
+            self.num_hidden_layers = num_hidden_layers
             self.scope = scope
             self.seed = seed
+            self.all_model_classes = all_model_classes
 
         def prepare_config_and_inputs(self):
-            input_ids_1 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-            input_ids_2 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
             lm_labels = None
             if self.use_labels:
-                lm_labels = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
             config = TransfoXLConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
                 mem_len=self.mem_len,
                 clamp_len=self.clamp_len,
                 cutoffs=self.cutoffs,
-                d_model=self.d_model,
+                d_model=self.hidden_size,
                 d_embed=self.d_embed,
-                n_head=self.n_head,
+                n_head=self.num_attention_heads,
                 d_head=self.d_head,
                 d_inner=self.d_inner,
                 div_val=self.div_val,
-                n_layer=self.n_layer)
+                n_layer=self.num_hidden_layers)
 
             return (config, input_ids_1, input_ids_2, lm_labels)
...
@@ -113,37 +119,34 @@ class TransfoXLModelTest(unittest.TestCase):
         def check_transfo_xl_model_output(self, result):
             self.parent.assertListEqual(
                 list(result["hidden_states_1"].size()),
-                [self.batch_size, self.seq_length, self.d_model])
+                [self.batch_size, self.seq_length, self.hidden_size])
             self.parent.assertListEqual(
                 list(result["hidden_states_2"].size()),
-                [self.batch_size, self.seq_length, self.d_model])
+                [self.batch_size, self.seq_length, self.hidden_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_1"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_2"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
 
         def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
             model = TransfoXLLMHeadModel(config)
             model.eval()
 
-            loss_1, mems_1a = model(input_ids_1, labels=lm_labels)
-            lm_logits_1, mems_1b = model(input_ids_1)
-            loss_2, mems_2a = model(input_ids_2, labels=lm_labels, mems=mems_1a)
-            lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)
+            lm_logits_1, mems_1 = model(input_ids_1)
+            loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
+            lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
+            loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)
 
             outputs = {
                 "loss_1": loss_1,
-                "mems_1a": mems_1a,
+                "mems_1": mems_1,
                 "lm_logits_1": lm_logits_1,
-                "mems_1b": mems_1b,
                 "loss_2": loss_2,
-                "mems_2a": mems_2a,
+                "mems_2": mems_2,
                 "lm_logits_2": lm_logits_2,
-                "mems_2b": mems_2b,
             }
             return outputs
...
@@ -155,14 +158,8 @@ class TransfoXLModelTest(unittest.TestCase):
...
@@ -155,14 +158,8 @@ class TransfoXLModelTest(unittest.TestCase):
                list(result["lm_logits_1"].size()),
                list(result["lm_logits_1"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1a"]),
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1b"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))
            self.parent.assertListEqual(
            self.parent.assertListEqual(
                list(result["loss_2"].size()),
                list(result["loss_2"].size()),
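(The memory checks above compare the two sets of returned mems with a NaN-masked sum, so that uninitialised or padded entries are ignored. A minimal standalone sketch of the same idea, with made-up tensors that are not taken from the test:)

import torch

# Two lists of per-layer memory tensors that should hold identical values.
mems_a = [torch.randn(3, 2, 8) for _ in range(4)]
mems_b = [m.clone() for m in mems_a]

# Sum only the finite entries of each layer's memory, then compare layer by layer.
sums_a = [m[~torch.isnan(m)].sum() for m in mems_a]
sums_b = [m[~torch.isnan(m)].sum() for m in mems_b]
assert all(torch.allclose(a, b) for a, b in zip(sums_a, sums_b))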
...
@@ -171,31 +168,19 @@ class TransfoXLModelTest(unittest.TestCase):
...
@@ -171,31 +168,19 @@ class TransfoXLModelTest(unittest.TestCase):
                list(result["lm_logits_2"].size()),
                list(result["lm_logits_2"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2a"]),
                list(list(mem.size()) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2b"]),
        def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            inputs_dict = {'input_ids': input_ids_1}
            self.parent.assertListEqual(
            create_and_check_commons(self, config, inputs_dict)
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))
    def test_default(self):
    def test_default(self):
        self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
        self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
    def test_config_to_json_string(self):
    def test_config(self):
        config = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
        config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
        obj = json.loads(config.to_json_string())
        config_tester.run_common_tests()
        self.assertEqual(obj["n_token"], 96)
        self.assertEqual(obj["d_embed"], 37)
    def test_config_to_json_file(self):
        config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
        json_file_path = "/tmp/config.json"
        config_first.to_json_file(json_file_path)
        config_second = TransfoXLConfig.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.assertEqual(config_second.to_dict(), config_first.to_dict())
    @pytest.mark.slow
    @pytest.mark.slow
    def test_model_from_pretrained(self):
    def test_model_from_pretrained(self):
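(The body of test_model_from_pretrained falls outside the hunk shown here; presumably it instantiates the released checkpoints through from_pretrained. A hypothetical, self-contained version of such a slow test follows; the shortcut name 'transfo-xl-wt103' is the published Transfo-XL checkpoint, not something read off this diff:)

import pytest
from pytorch_pretrained_bert import TransfoXLLMHeadModel

@pytest.mark.slow
def test_transfo_xl_from_pretrained_sketch():
    # Downloads and caches the public checkpoint, then checks that it loads cleanly.
    model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
    assert model is not None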
...
@@ -209,28 +194,18 @@ class TransfoXLModelTest(unittest.TestCase):
...
@@ -209,28 +194,18 @@ class TransfoXLModelTest(unittest.TestCase):
        config_and_inputs = tester.prepare_config_and_inputs()
        config_and_inputs = tester.prepare_config_and_inputs()
        tester.set_seed()
        tester.set_seed()
        config_and_inputs = tester.prepare_config_and_inputs()
        output_result = tester.create_transfo_xl_model(*config_and_inputs)
        output_result = tester.create_transfo_xl_model(*config_and_inputs)
        tester.check_transfo_xl_model_output(output_result)
        tester.check_transfo_xl_model_output(output_result)
        tester.set_seed()
        tester.set_seed()
        config_and_inputs = tester.prepare_config_and_inputs()
        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
        tester.check_transfo_xl_lm_head_output(output_result)
        tester.check_transfo_xl_lm_head_output(output_result)
    @classmethod
        tester.set_seed()
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        config_and_inputs = tester.prepare_config_and_inputs()
        """Creates a random int32 tensor of the shape within the vocab size."""
        tester.create_and_check_transfo_xl_commons(*config_and_inputs)
        if rng is None:
            rng = random.Random()
        total_dims = 1
        for dim in shape:
            total_dims *= dim
        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))
        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
if __name__ == "__main__":
if __name__ == "__main__":
...
...
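(Both rewritten test files delegate their generic checks to the new shared module pytorch_pretrained_bert/tests/model_tests_commons.py, which is added in this commit but not shown on this page. Purely as an illustration of the pattern, here is a rough sketch of what a create_and_check_commons helper along these lines could do; the real implementation may differ:)

import torch

def create_and_check_commons(tester, config, inputs_dict):
    # Instantiate every model class registered on the tester and run one forward pass.
    for model_class in tester.all_model_classes:
        model = model_class(config)
        model.eval()
        outputs = model(**inputs_dict)
        # Whatever the head returns, the first output should be a tensor of some kind.
        first = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        tester.parent.assertIsInstance(first, torch.Tensor)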
tests/modeling_xlnet_test.py → pytorch_pretrained_bert/tests/modeling_xlnet_test.py
View file @
1484d67d
...
@@ -25,9 +25,11 @@ import pytest
...
@@ -25,9 +25,11 @@ import pytest
import torch
import torch
from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel)
from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
class XLNetModelTest(unittest.TestCase):
class XLNetModelTest(unittest.TestCase):
    class XLNetModelTester(object):
    class XLNetModelTester(object):
...
@@ -42,43 +44,48 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -42,43 +44,48 @@ class XLNetModelTest(unittest.TestCase):
                     use_labels=True,
                     use_labels=True,
                     vocab_size=99,
                     vocab_size=99,
                     cutoffs=[10, 50, 80],
                     cutoffs=[10, 50, 80],
                     d_model=32,
                     hidden_size=32,
                     n_head=4,
                     num_attention_heads=4,
                     d_inner=128,
                     d_inner=128,
                     n_layer=5,
                     num_hidden_layers=5,
                     max_position_embeddings=10,
                     max_position_embeddings=10,
                     untie_r=True,
                     untie_r=True,
                     bi_data=False,
                     bi_data=False,
                     same_length=False,
                     same_length=False,
                     seed=1,
                     seed=1,
                     type_vocab_size=2):
                     type_vocab_size=2,
                     all_model_classes=(XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering),
                     ):
            self.parent = parent
            self.parent = parent
            self.batch_size = batch_size
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.seq_length = seq_length
            self.mem_len = mem_len
            self.mem_len = mem_len
            # self.key_len = seq_length + mem_len
            self.clamp_len = clamp_len
            self.clamp_len = clamp_len
            self.reuse_len = reuse_len
            self.reuse_len = reuse_len
            self.is_training = is_training
            self.is_training = is_training
            self.use_labels = use_labels
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
            self.cutoffs = cutoffs
            self.d_model = d_model
            self.hidden_size = hidden_size
            self.n_head = n_head
            self.num_attention_heads = num_attention_heads
            self.d_inner = d_inner
            self.d_inner = d_inner
            self.n_layer = n_layer
            self.num_hidden_layers = num_hidden_layers
            self.max_position_embeddings = max_position_embeddings
            self.max_position_embeddings = max_position_embeddings
            self.bi_data = bi_data
            self.bi_data = bi_data
            self.untie_r = untie_r
            self.untie_r = untie_r
            self.same_length = same_length
            self.same_length = same_length
            self.seed = seed
            self.seed = seed
            self.type_vocab_size = type_vocab_size
            self.type_vocab_size = type_vocab_size
            self.all_model_classes = all_model_classes
        def prepare_config_and_inputs(self):
        def prepare_config_and_inputs(self):
            input_ids_1 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_ids_2 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            segment_ids = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            input_ids_q = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
            input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
            perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
            perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
            perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
            perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
            target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
            target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
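(For context on the two masks built above: perm_mask[:, :, -1] = 1.0 hides the final token from every query position, while target_mapping selects which positions the LM head should predict. The line that marks the prediction target is outside the hunk shown here, but it presumably looks like the last line of this sketch, which uses arbitrary sizes for illustration:)

import torch

batch_size, seq_length = 2, 7  # arbitrary sizes for illustration
perm_mask = torch.zeros(batch_size, seq_length + 1, seq_length + 1, dtype=torch.float)
perm_mask[:, :, -1] = 1.0       # no position may attend to the last token
target_mapping = torch.zeros(batch_size, 1, seq_length + 1, dtype=torch.float)
target_mapping[:, 0, -1] = 1.0  # ask the model to predict only the last token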
...
@@ -89,8 +96,8 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -89,8 +96,8 @@ class XLNetModelTest(unittest.TestCase):
# token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
# token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
# input_mask: float32 Tensor in shape [bsz, len], the input mask.
# input_mask: float32 Tensor in shape [bsz, len], the input mask.
# 0 for real tokens and 1 for padding.
# 0 for real tokens and 1 for padding.
# mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory
# mems: a list of float32 Tensors in shape [bsz, mem_len, hidden_size], memory
# from previous batches. The length of the list equals n_layer.
# from previous batches. The length of the list equals num_hidden_layers.
# If None, no memory is used.
# If None, no memory is used.
# perm_mask: float32 Tensor in shape [bsz, len, len].
# perm_mask: float32 Tensor in shape [bsz, len, len].
# If perm_mask[k, i, j] = 0, i attend to j in batch k;
# If perm_mask[k, i, j] = 0, i attend to j in batch k;
...
@@ -108,14 +115,14 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -108,14 +115,14 @@ class XLNetModelTest(unittest.TestCase):
            lm_labels = None
            lm_labels = None
            if self.use_labels:
            if self.use_labels:
                lm_labels = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            config = XLNetConfig(
            config = XLNetConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                vocab_size_or_config_json_file=self.vocab_size,
                d_model=self.d_model,
                d_model=self.hidden_size,
                n_head=self.n_head,
                n_head=self.num_attention_heads,
                d_inner=self.d_inner,
                d_inner=self.d_inner,
                n_layer=self.n_layer,
                n_layer=self.num_hidden_layers,
                untie_r=self.untie_r,
                untie_r=self.untie_r,
                max_position_embeddings=self.max_position_embeddings,
                max_position_embeddings=self.max_position_embeddings,
                mem_len=self.mem_len,
                mem_len=self.mem_len,
...
@@ -159,7 +166,7 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -159,7 +166,7 @@ class XLNetModelTest(unittest.TestCase):
                [self.batch_size, self.seq_length, self.vocab_size])
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1"]),
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
            self.parent.assertListEqual(
            self.parent.assertListEqual(
                list(result["loss_2"].size()),
                list(result["loss_2"].size()),
...
@@ -169,24 +176,18 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -169,24 +176,18 @@ class XLNetModelTest(unittest.TestCase):
                [self.batch_size, self.seq_length, self.vocab_size])
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2"]),
                list(list(mem.size()) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
        def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
            inputs_dict = {'input_ids': input_ids_1}
            create_and_check_commons(self, config, inputs_dict)
    def test_default(self):
    def test_default(self):
        self.run_tester(XLNetModelTest.XLNetModelTester(self))
        self.run_tester(XLNetModelTest.XLNetModelTester(self))
    def test_config_to_json_string(self):
    def test_config(self):
        config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16 * 4)
        config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
        obj = json.loads(config.to_json_string())
        config_tester.run_common_tests()
        self.assertEqual(obj["n_token"], 96)
        self.assertEqual(obj["d_model"], 16 * 4)
    def test_config_to_json_file(self):
        config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16 * 4)
        json_file_path = "/tmp/config.json"
        config_first.to_json_file(json_file_path)
        config_second = XLNetConfig.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.assertEqual(config_second.to_dict(), config_first.to_dict())
    @pytest.mark.slow
    @pytest.mark.slow
    def test_model_from_pretrained(self):
    def test_model_from_pretrained(self):
...
@@ -197,27 +198,14 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -197,27 +198,14 @@ class XLNetModelTest(unittest.TestCase):
        self.assertIsNotNone(model)
        self.assertIsNotNone(model)
    def run_tester(self, tester):
    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()
        tester.set_seed()
        tester.set_seed()
        config_and_inputs = tester.prepare_config_and_inputs()
        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
        tester.check_transfo_xl_lm_head_output(output_result)
        tester.check_transfo_xl_lm_head_output(output_result)
    @classmethod
        tester.set_seed()
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        config_and_inputs = tester.prepare_config_and_inputs()
        """Creates a random int32 tensor of the shape within the vocab size."""
        tester.create_and_check_xlnet_commons(*config_and_inputs)
        if rng is None:
            rng = random.Random()
        total_dims = 1
        for dim in shape:
            total_dims *= dim
        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))
        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
    @classmethod
    @classmethod
    def mask_tensor(cls, shape, vocab_size, rng=None, name=None):
    def mask_tensor(cls, shape, vocab_size, rng=None, name=None):
...
...
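(The hand-written config tests removed above, the JSON-string and JSON-file round trips, are now centralised in ConfigTester.run_common_tests(). As a sketch only, assuming a ConfigTester that keeps the config_class and the extra keyword arguments it was constructed with, the equivalent checks could look like this; the actual helper lives in model_tests_commons.py:)

import json
import os
import tempfile

class ConfigTesterSketch(object):
    def __init__(self, parent, config_class=None, **kwargs):
        self.parent = parent
        self.config_class = config_class
        self.inputs_dict = kwargs

    def run_common_tests(self):
        config = self.config_class(vocab_size_or_config_json_file=96, **self.inputs_dict)
        # Round trip through a JSON string.
        obj = json.loads(config.to_json_string())
        for key, value in self.inputs_dict.items():
            self.parent.assertEqual(obj[key], value)
        # Round trip through a JSON file.
        json_file_path = os.path.join(tempfile.mkdtemp(), "config.json")
        config.to_json_file(json_file_path)
        config_second = self.config_class.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.parent.assertEqual(config_second.to_dict(), config.to_dict())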
tests/optimization_test.py → pytorch_pretrained_bert/tests/optimization_test.py
View file @
1484d67d
File moved