Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8fa3a1f0
Commit
8fa3a1f0
authored
Jul 03, 2019
by
thomwolf
Browse files
updating tests
parent
c41f2bad
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
69 additions
and
32 deletions
+69
-32
pytorch_pretrained_bert/modeling_xlm.py
pytorch_pretrained_bert/modeling_xlm.py
+31
-23
pytorch_pretrained_bert/tests/model_tests_commons.py
pytorch_pretrained_bert/tests/model_tests_commons.py
+6
-0
pytorch_pretrained_bert/tests/modeling_bert_test.py
pytorch_pretrained_bert/tests/modeling_bert_test.py
+2
-0
pytorch_pretrained_bert/tests/modeling_xlm_test.py
pytorch_pretrained_bert/tests/modeling_xlm_test.py
+4
-2
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
+26
-6
pytorch_pretrained_bert/tokenization_xlm.py
pytorch_pretrained_bert/tokenization_xlm.py
+0
-1
No files found.
pytorch_pretrained_bert/modeling_xlm.py
View file @
8fa3a1f0
...
...
@@ -35,7 +35,7 @@ from torch.nn import functional as F
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.file_utils
import
cached_path
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_linear_layer
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -46,24 +46,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-mlm-en-2048'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json"
,
}
DECODER_ONLY_PARAMS
=
[
'layer_norm15.%i.weight'
,
'layer_norm15.%i.bias'
,
'encoder_attn.%i.q_lin.weight'
,
'encoder_attn.%i.q_lin.bias'
,
'encoder_attn.%i.k_lin.weight'
,
'encoder_attn.%i.k_lin.bias'
,
'encoder_attn.%i.v_lin.weight'
,
'encoder_attn.%i.v_lin.bias'
,
'encoder_attn.%i.out_lin.weight'
,
'encoder_attn.%i.out_lin.bias'
]
TRANSFORMER_LAYER_PARAMS
=
[
'attentions.%i.q_lin.weight'
,
'attentions.%i.q_lin.bias'
,
'attentions.%i.k_lin.weight'
,
'attentions.%i.k_lin.bias'
,
'attentions.%i.v_lin.weight'
,
'attentions.%i.v_lin.bias'
,
'attentions.%i.out_lin.weight'
,
'attentions.%i.out_lin.bias'
,
'layer_norm1.%i.weight'
,
'layer_norm1.%i.bias'
,
'ffns.%i.lin1.weight'
,
'ffns.%i.lin1.bias'
,
'ffns.%i.lin2.weight'
,
'ffns.%i.lin2.bias'
,
'layer_norm2.%i.weight'
,
'layer_norm2.%i.bias'
]
class
XLMConfig
(
PretrainedConfig
):
"""Configuration class to store the configuration of a `XLMModel`.
...
...
@@ -275,6 +257,24 @@ class MultiHeadAttention(nn.Module):
self
.
v_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
out_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
def
prune_heads
(
self
,
heads
):
attention_head_size
=
self
.
dim
//
self
.
n_heads
if
len
(
heads
)
==
0
:
return
mask
=
torch
.
ones
(
self
.
n_heads
,
attention_head_size
)
for
head
in
heads
:
mask
[
head
]
=
0
mask
=
mask
.
view
(
-
1
).
contiguous
().
eq
(
1
)
index
=
torch
.
arange
(
len
(
mask
))[
mask
].
long
()
# Prune linear layers
self
.
q_lin
=
prune_linear_layer
(
self
.
q_lin
,
index
)
self
.
k_lin
=
prune_linear_layer
(
self
.
k_lin
,
index
)
self
.
v_lin
=
prune_linear_layer
(
self
.
v_lin
,
index
)
self
.
out_lin
=
prune_linear_layer
(
self
.
out_lin
,
index
,
dim
=
1
)
# Update hyper params
self
.
n_heads
=
self
.
n_heads
-
len
(
heads
)
self
.
dim
=
attention_head_size
*
self
.
n_heads
def
forward
(
self
,
input
,
mask
,
kv
=
None
,
cache
=
None
,
head_mask
=
None
):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
...
...
@@ -286,9 +286,9 @@ class MultiHeadAttention(nn.Module):
klen
=
qlen
if
cache
is
None
else
cache
[
'slen'
]
+
qlen
else
:
klen
=
kv
.
size
(
1
)
assert
dim
==
self
.
dim
,
'Dimensions do not match: %s input vs %s configured'
%
(
dim
,
self
.
dim
)
#
assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
n_heads
=
self
.
n_heads
dim_per_head
=
dim
//
n_heads
dim_per_head
=
self
.
dim
//
n_heads
mask_reshape
=
(
bs
,
1
,
qlen
,
klen
)
if
mask
.
dim
()
==
3
else
(
bs
,
1
,
1
,
klen
)
def
shape
(
x
):
...
...
@@ -335,7 +335,7 @@ class MultiHeadAttention(nn.Module):
outputs
=
(
self
.
out_lin
(
context
),)
if
self
.
output_attentions
:
outputs
=
outputs
+
(
weights
)
outputs
=
outputs
+
(
weights
,
)
return
outputs
...
...
@@ -497,6 +497,14 @@ class XLMModel(XLMPreTrainedModel):
self
.
ffns
.
append
(
TransformerFFN
(
self
.
dim
,
self
.
hidden_dim
,
self
.
dim
,
config
=
config
))
self
.
layer_norm2
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
))
def
_prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
attentions
[
layer
].
prune_heads
(
heads
)
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
cache
=
None
,
head_mask
=
None
):
# src_enc=None, src_len=None,
"""
...
...
@@ -508,7 +516,7 @@ class XLMModel(XLMPreTrainedModel):
`token_type_ids` LongTensor (bs, slen) same as `langs` used for compatibility
"""
if
lengths
is
None
:
lengths
=
(
input_ids
!=
self
.
pad_index
).
float
().
sum
(
dim
=
1
)
lengths
=
(
input_ids
!=
self
.
pad_index
).
sum
(
dim
=
1
)
.
long
()
# mask = input_ids != self.pad_index
# check inputs
...
...
pytorch_pretrained_bert/tests/model_tests_commons.py
View file @
8fa3a1f0
...
...
@@ -68,6 +68,8 @@ def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict
attentions
=
outputs
[
-
1
]
hidden_states
=
outputs
[
-
2
]
# Remove Nan
tester
.
parent
.
assertIsNotNone
(
multihead_outputs
)
tester
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
tester
.
num_hidden_layers
)
tester
.
parent
.
assertAlmostEqual
(
...
...
@@ -298,7 +300,11 @@ class GPTModelTester(object):
mc_labels
,
lm_labels
,
mc_token_ids
):
model
=
self
.
base_model_class
(
config
)
model
.
eval
()
outputs
=
model
(
input_ids
,
position_ids
,
token_type_ids
)
outputs
=
model
(
input_ids
,
position_ids
)
outputs
=
model
(
input_ids
)
hidden_state
=
outputs
[
0
]
self
.
parent
.
assertListEqual
(
list
(
hidden_state
.
size
()),
...
...
pytorch_pretrained_bert/tests/modeling_bert_test.py
View file @
8fa3a1f0
...
...
@@ -126,6 +126,8 @@ class BertModelTest(unittest.TestCase):
model
=
BertModel
(
config
=
config
)
model
.
eval
()
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
,
...
...
pytorch_pretrained_bert/tests/modeling_xlm_test.py
View file @
8fa3a1f0
...
...
@@ -96,7 +96,7 @@ class XLMModelTest(unittest.TestCase):
input_lengths
=
None
if
self
.
use_input_lengths
:
input_lengths
=
ids_tensor
([
self
.
batch_size
],
vocab_size
=
self
.
seq_length
-
1
)
input_lengths
=
ids_tensor
([
self
.
batch_size
],
vocab_size
=
2
)
+
self
.
seq_length
-
2
# small variation of seq_length
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
@@ -139,6 +139,8 @@ class XLMModelTest(unittest.TestCase):
model
=
XLMModel
(
config
=
config
)
model
.
eval
()
outputs
=
model
(
input_ids
,
lengths
=
input_lengths
,
langs
=
token_type_ids
)
outputs
=
model
(
input_ids
,
langs
=
token_type_ids
)
outputs
=
model
(
input_ids
)
sequence_output
=
outputs
[
0
]
result
=
{
"sequence_output"
:
sequence_output
,
...
...
@@ -232,7 +234,7 @@ class XLMModelTest(unittest.TestCase):
def
create_and_check_xlm_commons
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
choice_labels
):
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
'
attention_mask
'
:
input_lengths
}
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
'
lengths
'
:
input_lengths
}
create_and_check_commons
(
self
,
config
,
inputs_dict
)
def
test_default
(
self
):
...
...
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
View file @
8fa3a1f0
...
...
@@ -140,7 +140,26 @@ class XLNetModelTest(unittest.TestCase):
random
.
seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
def
create_transfo_xl_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
model
=
XLNetModel
(
config
)
model
.
eval
()
_
,
_
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
)
outputs
,
mems_1
=
model
(
input_ids_1
)
result
=
{
"mems_1"
:
mems_1
,
"outputs"
:
outputs
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"outputs"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
model
=
XLNetLMHeadModel
(
config
)
model
.
eval
()
...
...
@@ -150,7 +169,7 @@ class XLNetModelTest(unittest.TestCase):
logits
,
_
=
model
(
input_ids_q
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
inp_q
=
inp_q
)
outputs
=
{
result
=
{
"loss_1"
:
loss_1
,
"mems_1"
:
mems_1
,
"all_logits_1"
:
all_logits_1
,
...
...
@@ -158,9 +177,7 @@ class XLNetModelTest(unittest.TestCase):
"mems_2"
:
mems_2
,
"all_logits_2"
:
all_logits_2
,
}
return
outputs
def
check_transfo_xl_lm_head_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_1"
].
size
()),
[])
...
...
@@ -203,8 +220,11 @@ class XLNetModelTest(unittest.TestCase):
def
run_tester
(
self
,
tester
):
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
output_result
=
tester
.
create_transfo_xl_lm_head
(
*
config_and_inputs
)
tester
.
check_transfo_xl_lm_head_output
(
output_result
)
tester
.
create_and_check_xlnet_base_model
(
*
config_and_inputs
)
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
tester
.
create_and_check_xlnet_lm_head
(
*
config_and_inputs
)
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
...
...
pytorch_pretrained_bert/tokenization_xlm.py
View file @
8fa3a1f0
...
...
@@ -304,7 +304,6 @@ class XLMTokenizer(object):
index
=
0
with
open
(
merge_file
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
write
(
u
'#version: 0.2
\n
'
)
for
bpe_tokens
,
token_index
in
sorted
(
self
.
bpe_ranks
.
items
(),
key
=
lambda
kv
:
kv
[
1
]):
if
index
!=
token_index
:
logger
.
warning
(
"Saving vocabulary to {}: BPE merge indices are not consecutive."
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment