chenpangpang / transformers · Commits · 70887795

Commit 70887795 authored Jul 02, 2019 by thomwolf

    updating tests and models, adding weights initialization test

Parent: 99ae5ab8

Showing 13 changed files with 113 additions and 108 deletions.
pytorch_pretrained_bert/file_utils.py                       +2   -0
pytorch_pretrained_bert/model_utils.py                      +1   -2
pytorch_pretrained_bert/modeling.py                         +4   -4
pytorch_pretrained_bert/modeling_gpt2.py                    +7   -6
pytorch_pretrained_bert/modeling_openai.py                  +7   -6
pytorch_pretrained_bert/modeling_transfo_xl.py              +4   -2
pytorch_pretrained_bert/modeling_xlm.py                     +0   -1
pytorch_pretrained_bert/modeling_xlnet.py                   +7   -5
pytorch_pretrained_bert/tests/model_tests_commons.py        +69  -69
pytorch_pretrained_bert/tests/modeling_gpt2_test.py         +5   -7
pytorch_pretrained_bert/tests/modeling_openai_test.py       +1   -3
pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py   +1   -1
pytorch_pretrained_bert/tests/modeling_xlnet_test.py        +5   -2
pytorch_pretrained_bert/file_utils.py

@@ -191,6 +191,8 @@ def get_from_cache(url, cache_dir=None):
         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
+    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
+        cache_dir = str(cache_dir)
     if not os.path.exists(cache_dir):
         os.makedirs(cache_dir)
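The two added lines give the Python 2 path the same behaviour as the existing Python 3 branch: a pathlib-style cache_dir is coerced to a plain string before it reaches os.path.exists / os.makedirs. A minimal, self-contained sketch of that coercion (the helper name normalize_cache_dir is illustrative, not part of the commit):

    import sys
    from pathlib import Path  # on Python 2 a pathlib backport would provide Path

    def normalize_cache_dir(cache_dir):
        # Mirrors the coercion in the hunk above: downstream os.path.exists /
        # os.makedirs calls always receive a plain string.
        if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
            cache_dir = str(cache_dir)
        if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
            cache_dir = str(cache_dir)
        return cache_dir

    print(normalize_cache_dir(Path("/tmp/pytorch_pretrained_bert")))  # -> /tmp/pytorch_pretrained_bert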
pytorch_pretrained_bert/model_utils.py

@@ -60,8 +60,7 @@ class PretrainedConfig(object):
                 . `config.json` a configuration file for the model
             cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached.
         """
-        cache_dir = kwargs.get('cache_dir', None)
-        kwargs.pop('cache_dir', None)
+        cache_dir = kwargs.pop('cache_dir', None)
         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
             config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
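The replaced pair (kwargs.get followed by kwargs.pop) is collapsed into a single kwargs.pop, which returns the value and removes the key in one step, so the remaining kwargs can be forwarded without the 'cache_dir' entry. A small illustration with made-up values:

    kwargs = {'cache_dir': '/tmp/cache', 'output_attentions': True}

    # One call both retrieves and removes the entry, equivalent to the old
    # kwargs.get('cache_dir', None) followed by kwargs.pop('cache_dir', None).
    cache_dir = kwargs.pop('cache_dir', None)

    assert cache_dir == '/tmp/cache'
    assert kwargs == {'output_attentions': True}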
pytorch_pretrained_bert/modeling.py

@@ -17,7 +17,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

-import copy
 import json
 import logging
 import math

@@ -422,8 +421,7 @@ class BertEncoder(nn.Module):
         super(BertEncoder, self).__init__()
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        layer = BertLayer(config)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask, head_mask=None):
         all_hidden_states = []

@@ -539,10 +537,12 @@ class BertPreTrainedModel(PreTrainedModel):
     """
     config_class = BertConfig
     pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"

+    def __init__(self, *inputs, **kwargs):
+        super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def init_weights(self, module):
         """ Initialize the weights.
         """
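BertEncoder now builds each BertLayer directly inside the ModuleList comprehension instead of deep-copying a single prototype layer, so every layer gets its own independently constructed parameters rather than copies of one prototype (the same change is made for GPT-2, OpenAI GPT and XLNet below). A minimal sketch of the two patterns, using a made-up ToyLayer rather than the real BertLayer:

    import copy
    import torch.nn as nn

    class ToyLayer(nn.Module):  # stand-in for BertLayer; not part of the commit
        def __init__(self, hidden_size):
            super(ToyLayer, self).__init__()
            self.dense = nn.Linear(hidden_size, hidden_size)

    num_layers, hidden_size = 3, 8

    # Old pattern: deep-copy one prototype, so all layers start out as exact
    # copies of the same randomly initialized module.
    prototype = ToyLayer(hidden_size)
    layers_old = nn.ModuleList([copy.deepcopy(prototype) for _ in range(num_layers)])

    # New pattern (this commit): construct each layer independently, no copy needed.
    layers_new = nn.ModuleList([ToyLayer(hidden_size) for _ in range(num_layers)])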
pytorch_pretrained_bert/modeling_gpt2.py

@@ -18,7 +18,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import collections
-import copy
 import json
 import logging
 import math

@@ -378,18 +377,21 @@ class GPT2PreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"

+    def __init__(self, *inputs, **kwargs):
+        super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def init_weights(self, module):
         """ Initialize the weights.
         """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

@@ -489,8 +491,7 @@ class GPT2Model(GPT2PreTrainedModel):
         self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
-        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

         self.apply(self.init_weights)
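The substance of the init_weights hunk: Conv1D (GPT's transposed-linear projection) is now covered by the weight initialization, and Linear/Conv1D biases are zeroed right where the weights are drawn, instead of in a separate trailing check that only handled nn.Linear. The same change appears in modeling_openai.py below. A self-contained sketch of that logic, with a minimal stand-in Conv1D (the real one is defined in this file) and the LayerNorm branch omitted:

    import torch
    import torch.nn as nn

    class Conv1D(nn.Module):
        # Minimal stand-in for GPT's Conv1D (a transposed Linear):
        # it carries a weight and a bias, just like nn.Linear.
        def __init__(self, nf, nx):
            super(Conv1D, self).__init__()
            self.weight = nn.Parameter(torch.empty(nx, nf))
            self.bias = nn.Parameter(torch.zeros(nf))

    def init_weights(module, initializer_range=0.02):
        # Same structure as the updated GPT-2/OpenAI GPT init_weights above,
        # written as a free function so the sketch is self-contained.
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()

    block = nn.Sequential()          # any container works; apply() visits every submodule
    block.add_module('proj', Conv1D(8, 4))
    block.apply(init_weights)        # the Conv1D weight is drawn, its bias zeroed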
pytorch_pretrained_bert/modeling_openai.py

@@ -18,7 +18,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import collections
-import copy
 import json
 import logging
 import math

@@ -405,18 +404,21 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"

+    def __init__(self, *inputs, **kwargs):
+        super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def init_weights(self, module):
         """ Initialize the weights.
         """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

@@ -513,8 +515,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
-        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

         self.apply(self.init_weights)
pytorch_pretrained_bert/modeling_transfo_xl.py

@@ -21,7 +21,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import os
-import copy
 import json
 import math
 import logging

@@ -843,6 +842,9 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_transfo_xl
     base_model_prefix = "transformer"

+    def __init__(self, *inputs, **kwargs):
+        super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def _init_weight(self, weight):
         if self.config.init == 'uniform':
             nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)

@@ -883,7 +885,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
             nn.init.normal_(m.weight, 1.0, self.config.init_std)
             if hasattr(m, 'bias') and m.bias is not None:
                 self._init_bias(m.bias)
-        elif classname.find('TransformerLM') != -1:
+        else:
             if hasattr(m, 'r_emb'):
                 self._init_weight(m.r_emb)
             if hasattr(m, 'r_w_bias'):
pytorch_pretrained_bert/modeling_xlm.py

@@ -18,7 +18,6 @@ from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
 from __future__ import absolute_import, division, print_function, unicode_literals

-import copy
 import json
 import logging
 import math
pytorch_pretrained_bert/modeling_xlnet.py

@@ -19,7 +19,6 @@ from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
 from __future__ import absolute_import, division, print_function, unicode_literals

-import copy
 import json
 import logging
 import math

@@ -598,6 +597,8 @@ class XLNetPreTrainedModel(PreTrainedModel):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, XLNetLayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

@@ -606,8 +607,8 @@ class XLNetPreTrainedModel(PreTrainedModel):
                           module.r_r_bias, module.r_s_bias, module.r_w_bias,
                           module.seg_embed]:
                 param.data.normal_(mean=0.0, std=self.config.initializer_range)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
+        elif isinstance(module, XLNetModel):
+            module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)


 class XLNetModel(XLNetPreTrainedModel):

@@ -627,10 +628,11 @@ class XLNetModel(XLNetPreTrainedModel):
         self.word_embedding = nn.Embedding(config.n_token, config.d_model)
         self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
-        layer = XLNetLayer(config)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
+        self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
         self.dropout = nn.Dropout(config.dropout)

+        self.apply(self.init_weights)
+
     def _prune_heads(self, heads_to_prune):
         logger.info("Head pruning is not implemented for XLNet")
         pass
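Two of the added pieces work together: XLNetModel.__init__ now ends with self.apply(self.init_weights), and init_weights gains an elif isinstance(module, XLNetModel) branch that initializes mask_emb, a bare nn.Parameter that no submodule owns. nn.Module.apply(fn) visits every submodule and finally the module itself, which is what makes that branch fire. A minimal, made-up illustration (Tiny is not part of the commit):

    import torch
    import torch.nn as nn

    class Tiny(nn.Module):
        def __init__(self):
            super(Tiny, self).__init__()
            self.lin = nn.Linear(2, 2)
            self.extra = nn.Parameter(torch.Tensor(1, 2))  # bare parameter, like XLNet's mask_emb

            self.apply(self.init_weights)  # same pattern as the added call in XLNetModel.__init__

        def init_weights(self, module):
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, Tiny):
                # apply() also visits the root module itself (after its children),
                # so bare parameters such as self.extra can be initialized here.
                module.extra.data.normal_(mean=0.0, std=0.02)

    model = Tiny()  # lin and extra are both initialized by the time __init__ returns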
pytorch_pretrained_bert/tests/model_tests_commons.py

@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import copy
 import os
 import shutil
 import json

@@ -23,87 +24,84 @@ import random
 import torch

-def create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
+def _config_zero_init(config):
+    configs_no_init = copy.deepcopy(config)
+    for key in configs_no_init.__dict__.keys():
+        if '_range' in key or '_std' in key:
+            setattr(configs_no_init, key, 0.0)
+    return configs_no_init
+
+def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
+    configs_no_init = _config_zero_init(config)
+    for model_class in model_classes:
+        model = model_class(config=configs_no_init)
+        for name, param in model.named_parameters():
+            tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0],
+                msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
+
+def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
+    configs_no_init = _config_zero_init(config)
     for model_class in model_classes:
+        config.output_attentions = True
         config.output_hidden_states = True
-        model = model_class(config=config)
+        model = model_class(config=configs_no_init)
         model.eval()
-        head_mask = torch.zeros(tester.num_hidden_layers, tester.num_attention_heads)
-        # Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
+
+        # Prepare head_mask
+        # Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
+        head_mask = torch.ones(tester.num_hidden_layers, tester.num_attention_heads)
+        head_mask[0, 0] = 0
+        head_mask[-1, :-1] = 0
         head_mask.requires_grad_(requires_grad=True)
-        outputs = model(**inputs_dict, head_mask=head_mask)
+        inputs = inputs_dict.copy()
+        inputs['head_mask'] = head_mask
+
+        # Compute some gradients
+        outputs = model(**inputs)

+        # Test that we can get a gradient back for importance score computation
         output = sum(t.sum() for t in outputs[0])
         output = output.sum()
         output.backward()
         multihead_outputs = head_mask.grad

-        tester.parent.assertIsNotNone(multihead_outputs)
+        attentions = outputs[-1]
+        hidden_states = outputs[-2]
         tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[0].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
-        #     0)
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[0][:, 0, :, :].nonzero()),
-        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
-        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[1].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[1].nonzero()),
-        #     multihead_outputs[1].numel())
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[-1].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
-        #     0)
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[-1][:, 0, :, :].nonzero()),
-        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+        tester.parent.assertAlmostEqual(
+            attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertNotEqual(
+            attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertNotEqual(
+            attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertAlmostEqual(
+            attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertNotEqual(
+            attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)

-def create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
+def _create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
     for model_class in model_classes:
+        config.output_attentions = True
+        config.output_hidden_states = False
         model = model_class(config=config)
         model.eval()
         heads_to_prune = {0: list(range(1, tester.num_attention_heads)),
                           -1: [0]}
         model.prune_heads(heads_to_prune)
         outputs = model(**inputs_dict)

-        # output = sum(t.sum() for t in outputs[0])
-        # output = output.sum()
-        # output.backward()
-        # multihead_outputs = bert_model.get_multihead_outputs()
-        # self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[0].size()),
-        #     [self.batch_size, 1,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[1].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[-1].size()),
-        #     [self.batch_size, self.num_attention_heads-1,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
+        attentions = outputs[-1]
+        tester.parent.assertEqual(
+            attentions[0].shape[-3], 1)
+        tester.parent.assertEqual(
+            attentions[1].shape[-3], tester.num_attention_heads)
+        tester.parent.assertEqual(
+            attentions[-1].shape[-3], tester.num_attention_heads - 1)

-def create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
+def _create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
     for model_class in model_classes:
         config.output_attentions = True
         config.output_hidden_states = False

@@ -139,7 +137,7 @@ def create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
             tester.seq_length,
             tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])

-def create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
+def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
     for model_class in model_classes:
         config.output_hidden_states = True
         config.output_attentions = False

@@ -155,11 +153,13 @@ def create_and_check_for_hidden_states(tester, model_classes, config, inputs_dic
             [tester.seq_length, tester.hidden_size])

-def create_and_check_commons(tester, config, inputs_dict):
-    create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
-    create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
-    create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
-    create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
+def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
+    _create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
+    if test_pruning:
+        _create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)

 def ids_tensor(shape, vocab_size, rng=None, name=None):
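The new initialization check relies on a simple trick: _config_zero_init zeroes every *_range / *_std field in the config, so any parameter the model initializes with the configured normal or uniform scheme collapses to exactly 0.0 (LayerNorm-style weights are filled with 1.0), while a parameter the model forgot to initialize keeps PyTorch's default init and fails the assertIn([0.0, 1.0]) check. A standalone sketch of the mechanism, with a made-up ToyConfig/ToyModel instead of a real model class:

    import copy
    import torch.nn as nn

    class ToyConfig(object):
        def __init__(self):
            self.hidden_size = 8
            self.initializer_range = 0.02

    class ToyModel(nn.Module):
        def __init__(self, config):
            super(ToyModel, self).__init__()
            self.good = nn.Linear(config.hidden_size, config.hidden_size)
            self.bad = nn.Linear(config.hidden_size, config.hidden_size)
            # Only self.good is explicitly initialized; self.bad is "forgotten"
            # and keeps PyTorch's default (non-zero) initialization.
            self.good.weight.data.normal_(mean=0.0, std=config.initializer_range)
            self.good.bias.data.zero_()

    def _config_zero_init(config):
        # Same idea as the helper added in this commit.
        configs_no_init = copy.deepcopy(config)
        for key in configs_no_init.__dict__.keys():
            if '_range' in key or '_std' in key:
                setattr(configs_no_init, key, 0.0)
        return configs_no_init

    model = ToyModel(_config_zero_init(ToyConfig()))
    for name, param in model.named_parameters():
        mean = param.data.mean().item()
        # good.* parameters land exactly on 0.0; bad.* parameters do not.
        print(name, mean, mean in [0.0, 1.0])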
pytorch_pretrained_bert/tests/modeling_gpt2_test.py

@@ -28,9 +28,7 @@ import torch
 from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
                                      GPT2LMHeadModel, GPT2DoubleHeadsModel)
-from .model_tests_commons import (create_and_check_for_attentions, create_and_check_for_head_pruning,
-                                  create_and_check_for_headmasking, create_and_check_for_hidden_states,
-                                  ConfigTester, GPTModelTester)
+from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)

 class GPT2ModelTest(unittest.TestCase):

@@ -40,15 +38,15 @@ class GPT2ModelTest(unittest.TestCase):
     def test_model(self):
         model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
                                       lm_head_model_class=GPT2LMHeadModel,
                                       double_head_model_class=GPT2DoubleHeadsModel)
         model_tester.run_common_tests(test_presents=True)

     @pytest.mark.slow
     def test_pretrained(self):
         model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
                                       lm_head_model_class=GPT2LMHeadModel,
                                       double_head_model_class=GPT2DoubleHeadsModel)
         model_tester.run_slow_tests()

 if __name__ == "__main__":
pytorch_pretrained_bert/tests/modeling_openai_test.py

@@ -28,9 +28,7 @@ import torch
 from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
                                      OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
-from .model_tests_commons import (create_and_check_for_attentions, create_and_check_for_head_pruning,
-                                  create_and_check_for_headmasking, create_and_check_for_hidden_states,
-                                  ConfigTester, GPTModelTester)
+from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)

 class OpenAIModelTest(unittest.TestCase):
pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py

@@ -173,7 +173,7 @@ class TransfoXLModelTest(unittest.TestCase):
         def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
             inputs_dict = {'input_ids': input_ids_1}
-            create_and_check_commons(self, config, inputs_dict)
+            create_and_check_commons(self, config, inputs_dict, test_pruning=False)

     def test_default(self):
         self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
pytorch_pretrained_bert/tests/modeling_xlnet_test.py

@@ -52,6 +52,7 @@ class XLNetModelTest(unittest.TestCase):
                      untie_r=True,
                      bi_data=False,
                      same_length=False,
+                     initializer_range=0.05,
                      seed=1,
                      type_vocab_size=2,
                      all_model_classes=(XLNetModel, XLNetLMHeadModel,

@@ -76,6 +77,7 @@ class XLNetModelTest(unittest.TestCase):
             self.bi_data = bi_data
             self.untie_r = untie_r
             self.same_length = same_length
+            self.initializer_range = initializer_range
             self.seed = seed
             self.type_vocab_size = type_vocab_size
             self.all_model_classes = all_model_classes

@@ -129,7 +131,8 @@ class XLNetModelTest(unittest.TestCase):
                 clamp_len=self.clamp_len,
                 same_length=self.same_length,
                 reuse_len=self.reuse_len,
-                bi_data=self.bi_data)
+                bi_data=self.bi_data,
+                initializer_range=self.initializer_range)

             return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels)

@@ -180,7 +183,7 @@ class XLNetModelTest(unittest.TestCase):
         def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
             inputs_dict = {'input_ids': input_ids_1}
-            create_and_check_commons(self, config, inputs_dict)
+            create_and_check_commons(self, config, inputs_dict, test_pruning=False)

     def test_default(self):
         self.run_tester(XLNetModelTest.XLNetModelTester(self))