Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8fa3a1f0
Commit
8fa3a1f0
authored
Jul 03, 2019
by
thomwolf
Browse files
updating tests
parent
c41f2bad
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
69 additions
and
32 deletions
+69
-32
pytorch_pretrained_bert/modeling_xlm.py
pytorch_pretrained_bert/modeling_xlm.py
+31
-23
pytorch_pretrained_bert/tests/model_tests_commons.py
pytorch_pretrained_bert/tests/model_tests_commons.py
+6
-0
pytorch_pretrained_bert/tests/modeling_bert_test.py
pytorch_pretrained_bert/tests/modeling_bert_test.py
+2
-0
pytorch_pretrained_bert/tests/modeling_xlm_test.py
pytorch_pretrained_bert/tests/modeling_xlm_test.py
+4
-2
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
+26
-6
pytorch_pretrained_bert/tokenization_xlm.py
pytorch_pretrained_bert/tokenization_xlm.py
+0
-1
No files found.
pytorch_pretrained_bert/modeling_xlm.py
View file @
8fa3a1f0
...
@@ -35,7 +35,7 @@ from torch.nn import functional as F
...
@@ -35,7 +35,7 @@ from torch.nn import functional as F
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.file_utils
import
cached_path
from
.file_utils
import
cached_path
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_linear_layer
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -46,24 +46,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
...
@@ -46,24 +46,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-mlm-en-2048'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json"
,
'xlm-mlm-en-2048'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json"
,
}
}
DECODER_ONLY_PARAMS
=
[
'layer_norm15.%i.weight'
,
'layer_norm15.%i.bias'
,
'encoder_attn.%i.q_lin.weight'
,
'encoder_attn.%i.q_lin.bias'
,
'encoder_attn.%i.k_lin.weight'
,
'encoder_attn.%i.k_lin.bias'
,
'encoder_attn.%i.v_lin.weight'
,
'encoder_attn.%i.v_lin.bias'
,
'encoder_attn.%i.out_lin.weight'
,
'encoder_attn.%i.out_lin.bias'
]
TRANSFORMER_LAYER_PARAMS
=
[
'attentions.%i.q_lin.weight'
,
'attentions.%i.q_lin.bias'
,
'attentions.%i.k_lin.weight'
,
'attentions.%i.k_lin.bias'
,
'attentions.%i.v_lin.weight'
,
'attentions.%i.v_lin.bias'
,
'attentions.%i.out_lin.weight'
,
'attentions.%i.out_lin.bias'
,
'layer_norm1.%i.weight'
,
'layer_norm1.%i.bias'
,
'ffns.%i.lin1.weight'
,
'ffns.%i.lin1.bias'
,
'ffns.%i.lin2.weight'
,
'ffns.%i.lin2.bias'
,
'layer_norm2.%i.weight'
,
'layer_norm2.%i.bias'
]
class
XLMConfig
(
PretrainedConfig
):
class
XLMConfig
(
PretrainedConfig
):
"""Configuration class to store the configuration of a `XLMModel`.
"""Configuration class to store the configuration of a `XLMModel`.
...
@@ -275,6 +257,24 @@ class MultiHeadAttention(nn.Module):
...
@@ -275,6 +257,24 @@ class MultiHeadAttention(nn.Module):
self
.
v_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
v_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
out_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
out_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
def
prune_heads
(
self
,
heads
):
attention_head_size
=
self
.
dim
//
self
.
n_heads
if
len
(
heads
)
==
0
:
return
mask
=
torch
.
ones
(
self
.
n_heads
,
attention_head_size
)
for
head
in
heads
:
mask
[
head
]
=
0
mask
=
mask
.
view
(
-
1
).
contiguous
().
eq
(
1
)
index
=
torch
.
arange
(
len
(
mask
))[
mask
].
long
()
# Prune linear layers
self
.
q_lin
=
prune_linear_layer
(
self
.
q_lin
,
index
)
self
.
k_lin
=
prune_linear_layer
(
self
.
k_lin
,
index
)
self
.
v_lin
=
prune_linear_layer
(
self
.
v_lin
,
index
)
self
.
out_lin
=
prune_linear_layer
(
self
.
out_lin
,
index
,
dim
=
1
)
# Update hyper params
self
.
n_heads
=
self
.
n_heads
-
len
(
heads
)
self
.
dim
=
attention_head_size
*
self
.
n_heads
def
forward
(
self
,
input
,
mask
,
kv
=
None
,
cache
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input
,
mask
,
kv
=
None
,
cache
=
None
,
head_mask
=
None
):
"""
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
Self-attention (if kv is None) or attention over source sentence (provided by kv).
...
@@ -286,9 +286,9 @@ class MultiHeadAttention(nn.Module):
...
@@ -286,9 +286,9 @@ class MultiHeadAttention(nn.Module):
klen
=
qlen
if
cache
is
None
else
cache
[
'slen'
]
+
qlen
klen
=
qlen
if
cache
is
None
else
cache
[
'slen'
]
+
qlen
else
:
else
:
klen
=
kv
.
size
(
1
)
klen
=
kv
.
size
(
1
)
assert
dim
==
self
.
dim
,
'Dimensions do not match: %s input vs %s configured'
%
(
dim
,
self
.
dim
)
#
assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
n_heads
=
self
.
n_heads
n_heads
=
self
.
n_heads
dim_per_head
=
dim
//
n_heads
dim_per_head
=
self
.
dim
//
n_heads
mask_reshape
=
(
bs
,
1
,
qlen
,
klen
)
if
mask
.
dim
()
==
3
else
(
bs
,
1
,
1
,
klen
)
mask_reshape
=
(
bs
,
1
,
qlen
,
klen
)
if
mask
.
dim
()
==
3
else
(
bs
,
1
,
1
,
klen
)
def
shape
(
x
):
def
shape
(
x
):
...
@@ -335,7 +335,7 @@ class MultiHeadAttention(nn.Module):
...
@@ -335,7 +335,7 @@ class MultiHeadAttention(nn.Module):
outputs
=
(
self
.
out_lin
(
context
),)
outputs
=
(
self
.
out_lin
(
context
),)
if
self
.
output_attentions
:
if
self
.
output_attentions
:
outputs
=
outputs
+
(
weights
)
outputs
=
outputs
+
(
weights
,
)
return
outputs
return
outputs
...
@@ -497,6 +497,14 @@ class XLMModel(XLMPreTrainedModel):
...
@@ -497,6 +497,14 @@ class XLMModel(XLMPreTrainedModel):
self
.
ffns
.
append
(
TransformerFFN
(
self
.
dim
,
self
.
hidden_dim
,
self
.
dim
,
config
=
config
))
self
.
ffns
.
append
(
TransformerFFN
(
self
.
dim
,
self
.
hidden_dim
,
self
.
dim
,
config
=
config
))
self
.
layer_norm2
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
))
self
.
layer_norm2
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
))
def
_prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
attentions
[
layer
].
prune_heads
(
heads
)
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
cache
=
None
,
head_mask
=
None
):
# src_enc=None, src_len=None,
token_type_ids
=
None
,
attention_mask
=
None
,
cache
=
None
,
head_mask
=
None
):
# src_enc=None, src_len=None,
"""
"""
...
@@ -508,7 +516,7 @@ class XLMModel(XLMPreTrainedModel):
...
@@ -508,7 +516,7 @@ class XLMModel(XLMPreTrainedModel):
`token_type_ids` LongTensor (bs, slen) same as `langs` used for compatibility
`token_type_ids` LongTensor (bs, slen) same as `langs` used for compatibility
"""
"""
if
lengths
is
None
:
if
lengths
is
None
:
lengths
=
(
input_ids
!=
self
.
pad_index
).
float
().
sum
(
dim
=
1
)
lengths
=
(
input_ids
!=
self
.
pad_index
).
sum
(
dim
=
1
)
.
long
()
# mask = input_ids != self.pad_index
# mask = input_ids != self.pad_index
# check inputs
# check inputs
...
...
pytorch_pretrained_bert/tests/model_tests_commons.py
View file @
8fa3a1f0
...
@@ -68,6 +68,8 @@ def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict
...
@@ -68,6 +68,8 @@ def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict
attentions
=
outputs
[
-
1
]
attentions
=
outputs
[
-
1
]
hidden_states
=
outputs
[
-
2
]
hidden_states
=
outputs
[
-
2
]
# Remove Nan
tester
.
parent
.
assertIsNotNone
(
multihead_outputs
)
tester
.
parent
.
assertIsNotNone
(
multihead_outputs
)
tester
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
tester
.
num_hidden_layers
)
tester
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
tester
.
num_hidden_layers
)
tester
.
parent
.
assertAlmostEqual
(
tester
.
parent
.
assertAlmostEqual
(
...
@@ -298,7 +300,11 @@ class GPTModelTester(object):
...
@@ -298,7 +300,11 @@ class GPTModelTester(object):
mc_labels
,
lm_labels
,
mc_token_ids
):
mc_labels
,
lm_labels
,
mc_token_ids
):
model
=
self
.
base_model_class
(
config
)
model
=
self
.
base_model_class
(
config
)
model
.
eval
()
model
.
eval
()
outputs
=
model
(
input_ids
,
position_ids
,
token_type_ids
)
outputs
=
model
(
input_ids
,
position_ids
,
token_type_ids
)
outputs
=
model
(
input_ids
,
position_ids
)
outputs
=
model
(
input_ids
)
hidden_state
=
outputs
[
0
]
hidden_state
=
outputs
[
0
]
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
hidden_state
.
size
()),
list
(
hidden_state
.
size
()),
...
...
pytorch_pretrained_bert/tests/modeling_bert_test.py
View file @
8fa3a1f0
...
@@ -126,6 +126,8 @@ class BertModelTest(unittest.TestCase):
...
@@ -126,6 +126,8 @@ class BertModelTest(unittest.TestCase):
model
=
BertModel
(
config
=
config
)
model
=
BertModel
(
config
=
config
)
model
.
eval
()
model
.
eval
()
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
,
input_mask
)
sequence_output
,
pooled_output
=
model
(
input_ids
,
token_type_ids
)
sequence_output
,
pooled_output
=
model
(
input_ids
)
result
=
{
result
=
{
"sequence_output"
:
sequence_output
,
"sequence_output"
:
sequence_output
,
...
...
pytorch_pretrained_bert/tests/modeling_xlm_test.py
View file @
8fa3a1f0
...
@@ -96,7 +96,7 @@ class XLMModelTest(unittest.TestCase):
...
@@ -96,7 +96,7 @@ class XLMModelTest(unittest.TestCase):
input_lengths
=
None
input_lengths
=
None
if
self
.
use_input_lengths
:
if
self
.
use_input_lengths
:
input_lengths
=
ids_tensor
([
self
.
batch_size
],
vocab_size
=
self
.
seq_length
-
1
)
input_lengths
=
ids_tensor
([
self
.
batch_size
],
vocab_size
=
2
)
+
self
.
seq_length
-
2
# small variation of seq_length
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
...
@@ -139,6 +139,8 @@ class XLMModelTest(unittest.TestCase):
...
@@ -139,6 +139,8 @@ class XLMModelTest(unittest.TestCase):
model
=
XLMModel
(
config
=
config
)
model
=
XLMModel
(
config
=
config
)
model
.
eval
()
model
.
eval
()
outputs
=
model
(
input_ids
,
lengths
=
input_lengths
,
langs
=
token_type_ids
)
outputs
=
model
(
input_ids
,
lengths
=
input_lengths
,
langs
=
token_type_ids
)
outputs
=
model
(
input_ids
,
langs
=
token_type_ids
)
outputs
=
model
(
input_ids
)
sequence_output
=
outputs
[
0
]
sequence_output
=
outputs
[
0
]
result
=
{
result
=
{
"sequence_output"
:
sequence_output
,
"sequence_output"
:
sequence_output
,
...
@@ -232,7 +234,7 @@ class XLMModelTest(unittest.TestCase):
...
@@ -232,7 +234,7 @@ class XLMModelTest(unittest.TestCase):
def
create_and_check_xlm_commons
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_and_check_xlm_commons
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
choice_labels
):
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
'
attention_mask
'
:
input_lengths
}
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
'
lengths
'
:
input_lengths
}
create_and_check_commons
(
self
,
config
,
inputs_dict
)
create_and_check_commons
(
self
,
config
,
inputs_dict
)
def
test_default
(
self
):
def
test_default
(
self
):
...
...
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
View file @
8fa3a1f0
...
@@ -140,7 +140,26 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -140,7 +140,26 @@ class XLNetModelTest(unittest.TestCase):
random
.
seed
(
self
.
seed
)
random
.
seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
def
create_transfo_xl_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
model
=
XLNetModel
(
config
)
model
.
eval
()
_
,
_
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
)
outputs
,
mems_1
=
model
(
input_ids_1
)
result
=
{
"mems_1"
:
mems_1
,
"outputs"
:
outputs
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"outputs"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
model
=
XLNetLMHeadModel
(
config
)
model
=
XLNetLMHeadModel
(
config
)
model
.
eval
()
model
.
eval
()
...
@@ -150,7 +169,7 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -150,7 +169,7 @@ class XLNetModelTest(unittest.TestCase):
logits
,
_
=
model
(
input_ids_q
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
inp_q
=
inp_q
)
logits
,
_
=
model
(
input_ids_q
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
inp_q
=
inp_q
)
outputs
=
{
result
=
{
"loss_1"
:
loss_1
,
"loss_1"
:
loss_1
,
"mems_1"
:
mems_1
,
"mems_1"
:
mems_1
,
"all_logits_1"
:
all_logits_1
,
"all_logits_1"
:
all_logits_1
,
...
@@ -158,9 +177,7 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -158,9 +177,7 @@ class XLNetModelTest(unittest.TestCase):
"mems_2"
:
mems_2
,
"mems_2"
:
mems_2
,
"all_logits_2"
:
all_logits_2
,
"all_logits_2"
:
all_logits_2
,
}
}
return
outputs
def
check_transfo_xl_lm_head_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_1"
].
size
()),
list
(
result
[
"loss_1"
].
size
()),
[])
[])
...
@@ -203,8 +220,11 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -203,8 +220,11 @@ class XLNetModelTest(unittest.TestCase):
def
run_tester
(
self
,
tester
):
def
run_tester
(
self
,
tester
):
tester
.
set_seed
()
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
output_result
=
tester
.
create_transfo_xl_lm_head
(
*
config_and_inputs
)
tester
.
create_and_check_xlnet_base_model
(
*
config_and_inputs
)
tester
.
check_transfo_xl_lm_head_output
(
output_result
)
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
tester
.
create_and_check_xlnet_lm_head
(
*
config_and_inputs
)
tester
.
set_seed
()
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
...
...
pytorch_pretrained_bert/tokenization_xlm.py
View file @
8fa3a1f0
...
@@ -304,7 +304,6 @@ class XLMTokenizer(object):
...
@@ -304,7 +304,6 @@ class XLMTokenizer(object):
index
=
0
index
=
0
with
open
(
merge_file
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
with
open
(
merge_file
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
write
(
u
'#version: 0.2
\n
'
)
for
bpe_tokens
,
token_index
in
sorted
(
self
.
bpe_ranks
.
items
(),
key
=
lambda
kv
:
kv
[
1
]):
for
bpe_tokens
,
token_index
in
sorted
(
self
.
bpe_ranks
.
items
(),
key
=
lambda
kv
:
kv
[
1
]):
if
index
!=
token_index
:
if
index
!=
token_index
:
logger
.
warning
(
"Saving vocabulary to {}: BPE merge indices are not consecutive."
logger
.
warning
(
"Saving vocabulary to {}: BPE merge indices are not consecutive."
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment