chenpangpang / transformers / Commits / 70887795

Commit 70887795, authored Jul 02, 2019 by thomwolf

updating tests and models, adding weights initialization test

parent 99ae5ab8
Showing 13 changed files with 113 additions and 108 deletions
pytorch_pretrained_bert/file_utils.py                              +2   -0
pytorch_pretrained_bert/model_utils.py                             +1   -2
pytorch_pretrained_bert/modeling.py                                +4   -4
pytorch_pretrained_bert/modeling_gpt2.py                           +7   -6
pytorch_pretrained_bert/modeling_openai.py                         +7   -6
pytorch_pretrained_bert/modeling_transfo_xl.py                     +4   -2
pytorch_pretrained_bert/modeling_xlm.py                            +0   -1
pytorch_pretrained_bert/modeling_xlnet.py                          +7   -5
pytorch_pretrained_bert/tests/model_tests_commons.py               +69  -69
pytorch_pretrained_bert/tests/modeling_gpt2_test.py                +5   -7
pytorch_pretrained_bert/tests/modeling_openai_test.py              +1   -3
pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py          +1   -1
pytorch_pretrained_bert/tests/modeling_xlnet_test.py               +5   -2
pytorch_pretrained_bert/file_utils.py

@@ -191,6 +191,8 @@ def get_from_cache(url, cache_dir=None):
         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
+    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
+        cache_dir = str(cache_dir)
     if not os.path.exists(cache_dir):
         os.makedirs(cache_dir)
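The two added lines make the Python 2 path mirror the Python 3 one: a pathlib.Path (or unicode) cache_dir is normalized to a plain str before os.path.exists / os.makedirs are called. A minimal sketch of the same normalization in isolation (the function and default constant below are illustrative, not the library's):

import os
import sys
from pathlib import Path

DEFAULT_CACHE = os.path.expanduser("~/.pytorch_pretrained_bert")  # stand-in for PYTORCH_PRETRAINED_BERT_CACHE

def normalize_cache_dir(cache_dir=None):
    # Fall back to the default cache location when none is given
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE
    # Python 3: convert a pathlib.Path to str so os.path functions accept it uniformly
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    # Python 2: convert anything that is not already a str (e.g. unicode) as well
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
        cache_dir = str(cache_dir)
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    return cache_dir

print(normalize_cache_dir(Path("/tmp/bert-cache")))  # -> '/tmp/bert-cache'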
pytorch_pretrained_bert/model_utils.py

@@ -60,8 +60,7 @@ class PretrainedConfig(object):
                 . `config.json` a configuration file for the model
             cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached.
         """
-        cache_dir = kwargs.get('cache_dir', None)
-        kwargs.pop('cache_dir', None)
+        cache_dir = kwargs.pop('cache_dir', None)
         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
             config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
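The replacement collapses a get-then-pop pair into a single call: dict.pop('cache_dir', None) both returns the value (or None) and removes the key, so 'cache_dir' is not forwarded again inside **kwargs. A quick illustration:

kwargs = {'cache_dir': '/tmp/cache', 'output_attentions': True}

# Old pattern: two dictionary operations for one effect
cache_dir = kwargs.get('cache_dir', None)   # read the value
kwargs.pop('cache_dir', None)               # then discard the key

# New pattern: pop() returns the value and removes the key in one step
kwargs = {'cache_dir': '/tmp/cache', 'output_attentions': True}
cache_dir = kwargs.pop('cache_dir', None)

print(cache_dir)   # /tmp/cache
print(kwargs)      # {'output_attentions': True}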
pytorch_pretrained_bert/modeling.py

@@ -17,7 +17,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import copy
 import json
 import logging
 import math

@@ -422,8 +421,7 @@ class BertEncoder(nn.Module):
         super(BertEncoder, self).__init__()
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        layer = BertLayer(config)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask, head_mask=None):
         all_hidden_states = []

@@ -539,10 +537,12 @@ class BertPreTrainedModel(PreTrainedModel):
     """
     config_class = BertConfig
     pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
     pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"

     def __init__(self, *inputs, **kwargs):
         super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)

     def init_weights(self, module):
         """ Initialize the weights.
         """
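The BertEncoder change drops the deepcopy idiom: instead of building one BertLayer and cloning it, the ModuleList now constructs each layer directly, so every layer's parameters are created (and later re-initialized by init_weights) independently rather than starting out as bit-identical copies. A minimal sketch of the difference using plain nn.Linear layers (illustrative only, not the library's classes):

import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

# Old idiom: one layer, deep-copied N times -> all copies start with identical weights
layer = nn.Linear(4, 4)
cloned = nn.ModuleList([copy.deepcopy(layer) for _ in range(3)])
print(torch.equal(cloned[0].weight, cloned[1].weight))   # True

# New idiom: construct each layer separately -> each gets its own random initialization
fresh = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
print(torch.equal(fresh[0].weight, fresh[1].weight))     # False (with overwhelming probability)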
pytorch_pretrained_bert/modeling_gpt2.py

@@ -18,7 +18,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import collections
 import copy
 import json
 import logging
 import math

@@ -378,18 +377,21 @@ class GPT2PreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"

+    def __init__(self, *inputs, **kwargs):
+        super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def init_weights(self, module):
         """ Initialize the weights.
         """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

@@ -489,8 +491,7 @@ class GPT2Model(GPT2PreTrainedModel):
         self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
-        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
         self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

         self.apply(self.init_weights)
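The reworked init_weights now treats GPT-2's Conv1D projections like Linear and Embedding weights (normal initialization with std = initializer_range), zeroes their biases inside that same branch, and still gives LayerNorm a zero bias and unit weight; the constructor then pushes this function to every submodule via self.apply(self.init_weights). A minimal, self-contained sketch of that apply pattern using only stock torch.nn modules (the std value and module layout are illustrative):

import torch.nn as nn

INITIALIZER_RANGE = 0.02  # plays the role of config.initializer_range

def init_weights(module):
    # Weight matrices and embeddings: normal(0, initializer_range); biases start at zero
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=INITIALIZER_RANGE)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    # LayerNorm: identity transform at start (weight = 1, bias = 0)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8), nn.LayerNorm(8))
model.apply(init_weights)  # nn.Module.apply walks every submodule recursively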
pytorch_pretrained_bert/modeling_openai.py

@@ -18,7 +18,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import collections
 import copy
 import json
 import logging
 import math

@@ -405,18 +404,21 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"

+    def __init__(self, *inputs, **kwargs):
+        super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def init_weights(self, module):
         """ Initialize the weights.
         """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

@@ -513,8 +515,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
-        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

         self.apply(self.init_weights)
pytorch_pretrained_bert/modeling_transfo_xl.py

@@ -21,7 +21,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import os
 import copy
 import json
 import math
 import logging

@@ -843,6 +842,9 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_transfo_xl
     base_model_prefix = "transformer"

+    def __init__(self, *inputs, **kwargs):
+        super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def _init_weight(self, weight):
         if self.config.init == 'uniform':
             nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)

@@ -883,7 +885,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
             nn.init.normal_(m.weight, 1.0, self.config.init_std)
             if hasattr(m, 'bias') and m.bias is not None:
                 self._init_bias(m.bias)
-        elif classname.find('TransformerLM') != -1:
+        else:
             if hasattr(m, 'r_emb'):
                 self._init_weight(m.r_emb)
             if hasattr(m, 'r_w_bias'):
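Transformer-XL keeps its own initialization scheme: _init_weight dispatches on config.init, drawing either from a uniform distribution over [-init_range, init_range] or from a normal distribution with std init_std, and the last hunk widens the catch-all branch so relative-position parameters such as r_emb and r_w_bias are initialized on any module that carries them, not only on classes whose name contains 'TransformerLM'. A small sketch of that dispatch (the config object is a stand-in, not TransfoXLConfig):

import torch
import torch.nn as nn

class InitConfig(object):
    # Illustrative stand-in for the Transformer-XL config's init fields
    init = 'normal'        # or 'uniform'
    init_range = 0.01      # used by the uniform scheme
    init_std = 0.02        # used by the normal scheme

def init_weight(weight, config):
    if config.init == 'uniform':
        nn.init.uniform_(weight, -config.init_range, config.init_range)
    elif config.init == 'normal':
        nn.init.normal_(weight, 0.0, config.init_std)

w = torch.empty(4, 4)
init_weight(w, InitConfig())
print(w.std())  # roughly init_std under the 'normal' scheme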
pytorch_pretrained_bert/modeling_xlm.py

@@ -18,7 +18,6 @@ from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
-from __future__ import absolute_import, division, print_function, unicode_literals

 import copy
 import json
 import logging
 import math
pytorch_pretrained_bert/modeling_xlnet.py

@@ -19,7 +19,6 @@ from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
-from __future__ import absolute_import, division, print_function, unicode_literals

 import copy
 import json
 import logging
 import math

@@ -598,6 +597,8 @@ class XLNetPreTrainedModel(PreTrainedModel):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, XLNetLayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

@@ -606,8 +607,8 @@ class XLNetPreTrainedModel(PreTrainedModel):
                           module.r_r_bias, module.r_s_bias, module.r_w_bias, module.seg_embed]:
                 param.data.normal_(mean=0.0, std=self.config.initializer_range)
             if isinstance(module, nn.Linear) and module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, XLNetModel):
             module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)

 class XLNetModel(XLNetPreTrainedModel):

@@ -627,10 +628,11 @@ class XLNetModel(XLNetPreTrainedModel):
         self.word_embedding = nn.Embedding(config.n_token, config.d_model)
         self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
-        layer = XLNetLayer(config)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
+        self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
         self.dropout = nn.Dropout(config.dropout)

         self.apply(self.init_weights)

     def _prune_heads(self, heads_to_prune):
         logger.info("Head pruning is not implemented for XLNet")
         pass
pytorch_pretrained_bert/tests/model_tests_commons.py

@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import copy
 import os
 import shutil
 import json

@@ -23,87 +24,84 @@ import random

 import torch

-def create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
+def _config_zero_init(config):
+    configs_no_init = copy.deepcopy(config)
+    for key in configs_no_init.__dict__.keys():
+        if '_range' in key or '_std' in key:
+            setattr(configs_no_init, key, 0.0)
+    return configs_no_init
+
+def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
+    configs_no_init = _config_zero_init(config)
+    for model_class in model_classes:
+        model = model_class(config=configs_no_init)
+        for name, param in model.named_parameters():
+            tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0],
+                                   msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
+
+def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
+    configs_no_init = _config_zero_init(config)
     for model_class in model_classes:
         config.output_attentions = True
         config.output_hidden_states = True
-        model = model_class(config=config)
+        model = model_class(config=configs_no_init)
         model.eval()
-        head_mask = torch.zeros(tester.num_hidden_layers, tester.num_attention_heads)
-        # Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
+        # Prepare head_mask
+        # Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
+        head_mask = torch.ones(tester.num_hidden_layers, tester.num_attention_heads)
+        head_mask[0, 0] = 0
+        head_mask[-1, :-1] = 0
         head_mask.requires_grad_(requires_grad=True)
-        outputs = model(**inputs_dict, head_mask=head_mask)
+        inputs = inputs_dict.copy()
+        inputs['head_mask'] = head_mask
+        # Compute some gradients
+        outputs = model(**inputs)

         # Test that we can get a gradient back for importance score computation
         output = sum(t.sum() for t in outputs[0])
         output = output.sum()
         output.backward()
         multihead_outputs = head_mask.grad

+        attentions = outputs[-1]
+        hidden_states = outputs[-2]
         tester.parent.assertIsNotNone(multihead_outputs)
         tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[0].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
-        #     0)
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[0][:, 0, :, :].nonzero()),
-        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
-        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[1].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[1].nonzero()),
-        #     multihead_outputs[1].numel())
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[-1].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
-        #     0)
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[-1][:, 0, :, :].nonzero()),
-        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-
-def create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
+        tester.parent.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
+
+def _create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
     for model_class in model_classes:
         config.output_attentions = True
         config.output_hidden_states = False
         model = model_class(config=config)
         model.eval()
         heads_to_prune = {0: list(range(1, tester.num_attention_heads)),
                           -1: [0]}
         model.prune_heads(heads_to_prune)
         outputs = model(**inputs_dict)
-        # output = sum(t.sum() for t in outputs[0])
-        # output = output.sum()
-        # output.backward()
-        # multihead_outputs = bert_model.get_multihead_outputs()
-        # self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[0].size()),
-        #     [self.batch_size, 1,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[1].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[-1].size()),
-        #     [self.batch_size, self.num_attention_heads-1,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-
-def create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
+        attentions = outputs[-1]
+        tester.parent.assertEqual(attentions[0].shape[-3], 1)
+        tester.parent.assertEqual(attentions[1].shape[-3], tester.num_attention_heads)
+        tester.parent.assertEqual(attentions[-1].shape[-3], tester.num_attention_heads - 1)
+
+def _create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
     for model_class in model_classes:
         config.output_attentions = True
         config.output_hidden_states = False
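The new initialization check relies on a simple trick: _config_zero_init copies the config and sets every *_range / *_std field to 0.0, so any weight that the model's init_weights actually touches ends up with mean exactly 0.0 (a zero-std normal draw) or exactly 1.0 (LayerNorm weights filled with ones); a parameter whose mean is anything else was never initialized. A minimal sketch of the idea against a toy module (the class and field names are illustrative, not the test suite's):

import copy
import torch.nn as nn

class ToyConfig(object):
    def __init__(self):
        self.initializer_range = 0.02
        self.hidden_size = 8

def config_zero_init(config):
    # Zero out every *_range / *_std field so "initialized" weights become exactly 0.0
    cfg = copy.deepcopy(config)
    for key in cfg.__dict__.keys():
        if '_range' in key or '_std' in key:
            setattr(cfg, key, 0.0)
    return cfg

class ToyModel(nn.Module):
    def __init__(self, config):
        super(ToyModel, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.norm = nn.LayerNorm(config.hidden_size)
        self.forgotten = nn.Linear(config.hidden_size, 2)  # deliberately left out of the init below
        self.dense.weight.data.normal_(mean=0.0, std=config.initializer_range)
        self.dense.bias.data.zero_()
        self.norm.weight.data.fill_(1.0)
        self.norm.bias.data.zero_()

model = ToyModel(config_zero_init(ToyConfig()))
for name, param in model.named_parameters():
    ok = param.data.mean().item() in [0.0, 1.0]
    print(name, "initialized" if ok else "NOT initialized")  # flags forgotten.weight / forgotten.bias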
@@ -139,7 +137,7 @@ def create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
                 tester.seq_length,
                 tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])

-def create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
+def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
     for model_class in model_classes:
         config.output_hidden_states = True
         config.output_attentions = False

@@ -155,11 +153,13 @@ def create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
                 [tester.seq_length, tester.hidden_size])

-def create_and_check_commons(tester, config, inputs_dict):
-    create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
-    create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
-    create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
-    create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
+def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
+    _create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
+    if test_pruning:
+        _create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)

 def ids_tensor(shape, vocab_size, rng=None, name=None):
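The reworked head-masking check builds a head_mask of ones with a few entries zeroed (head 0 of the first layer, all but the last head of the last layer), marks it as requiring gradients, and backpropagates a scalar built from the model output; head_mask.grad then serves as a per-head importance signal, and the masked heads should carry exactly zero attention mass. A minimal sketch of the gradient-through-the-mask mechanism, with a toy "model" standing in for a transformer (shapes and names are illustrative):

import torch

num_layers, num_heads = 3, 4

# Per-layer, per-head scores standing in for whatever the model computes per head
head_scores = torch.randn(num_layers, num_heads)

# Build the mask first, then mark it as requiring grad (avoids the
# "leaf variable has been moved into the graph interior" error)
head_mask = torch.ones(num_layers, num_heads)
head_mask[0, 0] = 0        # disable head 0 of the first layer
head_mask[-1, :-1] = 0     # keep only the last head of the last layer
head_mask.requires_grad_(requires_grad=True)

# Stand-in forward pass: masked heads contribute nothing to the output
output = (head_scores * head_mask).sum()
output.backward()

# head_mask.grad gives one importance value per (layer, head)
print(head_mask.grad.shape)                      # torch.Size([3, 4])
print(torch.equal(head_mask.grad, head_scores))  # True for this toy model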
pytorch_pretrained_bert/tests/modeling_gpt2_test.py

@@ -28,9 +28,7 @@ import torch

 from pytorch_pretrained_bert import (GPT2Config, GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel)
-from .model_tests_commons import (create_and_check_for_attentions, create_and_check_for_head_pruning,
-                                  create_and_check_for_headmasking, create_and_check_for_hidden_states,
-                                  ConfigTester, GPTModelTester)
+from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)

 class GPT2ModelTest(unittest.TestCase):

@@ -40,15 +38,15 @@ class GPT2ModelTest(unittest.TestCase):
     def test_model(self):
         model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
                                       lm_head_model_class=GPT2LMHeadModel, double_head_model_class=GPT2DoubleHeadsModel)
         model_tester.run_common_tests(test_presents=True)

     @pytest.mark.slow
     def test_pretrained(self):
         model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
                                       lm_head_model_class=GPT2LMHeadModel, double_head_model_class=GPT2DoubleHeadsModel)
         model_tester.run_slow_tests()

 if __name__ == "__main__":
pytorch_pretrained_bert/tests/modeling_openai_test.py

@@ -28,9 +28,7 @@ import torch

 from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
-from .model_tests_commons import (create_and_check_for_attentions, create_and_check_for_head_pruning,
-                                  create_and_check_for_headmasking, create_and_check_for_hidden_states,
-                                  ConfigTester, GPTModelTester)
+from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)

 class OpenAIModelTest(unittest.TestCase):
pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py

@@ -173,7 +173,7 @@ class TransfoXLModelTest(unittest.TestCase):
     def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
         inputs_dict = {'input_ids': input_ids_1}
-        create_and_check_commons(self, config, inputs_dict)
+        create_and_check_commons(self, config, inputs_dict, test_pruning=False)

     def test_default(self):
         self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
pytorch_pretrained_bert/tests/modeling_xlnet_test.py

@@ -52,6 +52,7 @@ class XLNetModelTest(unittest.TestCase):
                  untie_r=True,
                  bi_data=False,
                  same_length=False,
+                 initializer_range=0.05,
                  seed=1,
                  type_vocab_size=2,
                  all_model_classes=(XLNetModel, XLNetLMHeadModel,

@@ -76,6 +77,7 @@ class XLNetModelTest(unittest.TestCase):
         self.bi_data = bi_data
         self.untie_r = untie_r
         self.same_length = same_length
+        self.initializer_range = initializer_range
         self.seed = seed
         self.type_vocab_size = type_vocab_size
         self.all_model_classes = all_model_classes

@@ -129,7 +131,8 @@ class XLNetModelTest(unittest.TestCase):
             clamp_len=self.clamp_len,
             same_length=self.same_length,
             reuse_len=self.reuse_len,
-            bi_data=self.bi_data)
+            bi_data=self.bi_data,
+            initializer_range=self.initializer_range)

         return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping,
                 inp_q, segment_ids, lm_labels)

@@ -180,7 +183,7 @@ class XLNetModelTest(unittest.TestCase):
     def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q,
                                        perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
         inputs_dict = {'input_ids': input_ids_1}
-        create_and_check_commons(self, config, inputs_dict)
+        create_and_check_commons(self, config, inputs_dict, test_pruning=False)

     def test_default(self):
         self.run_tester(XLNetModelTest.XLNetModelTester(self))