chenpangpang / transformers
"...test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "cbb63c5bec618354a25583c0861f45d4a01d9812"
Commit 8fa3a1f0, authored Jul 03, 2019 by thomwolf

    updating tests

Parent: c41f2bad
Showing 6 changed files with 69 additions and 32 deletions.
pytorch_pretrained_bert/modeling_xlm.py                  +31 -23
pytorch_pretrained_bert/tests/model_tests_commons.py      +6  -0
pytorch_pretrained_bert/tests/modeling_bert_test.py       +2  -0
pytorch_pretrained_bert/tests/modeling_xlm_test.py        +4  -2
pytorch_pretrained_bert/tests/modeling_xlnet_test.py     +26  -6
pytorch_pretrained_bert/tokenization_xlm.py               +0  -1
pytorch_pretrained_bert/modeling_xlm.py
@@ -35,7 +35,7 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss

 from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
+from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer

 logger = logging.getLogger(__name__)
@@ -46,24 +46,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
 }
-DECODER_ONLY_PARAMS = [
-    'layer_norm15.%i.weight', 'layer_norm15.%i.bias',
-    'encoder_attn.%i.q_lin.weight', 'encoder_attn.%i.q_lin.bias',
-    'encoder_attn.%i.k_lin.weight', 'encoder_attn.%i.k_lin.bias',
-    'encoder_attn.%i.v_lin.weight', 'encoder_attn.%i.v_lin.bias',
-    'encoder_attn.%i.out_lin.weight', 'encoder_attn.%i.out_lin.bias'
-]
-
-TRANSFORMER_LAYER_PARAMS = [
-    'attentions.%i.q_lin.weight', 'attentions.%i.q_lin.bias',
-    'attentions.%i.k_lin.weight', 'attentions.%i.k_lin.bias',
-    'attentions.%i.v_lin.weight', 'attentions.%i.v_lin.bias',
-    'attentions.%i.out_lin.weight', 'attentions.%i.out_lin.bias',
-    'layer_norm1.%i.weight', 'layer_norm1.%i.bias',
-    'ffns.%i.lin1.weight', 'ffns.%i.lin1.bias',
-    'ffns.%i.lin2.weight', 'ffns.%i.lin2.bias',
-    'layer_norm2.%i.weight', 'layer_norm2.%i.bias'
-]

 class XLMConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `XLMModel`.
@@ -275,6 +257,24 @@ class MultiHeadAttention(nn.Module):
         self.v_lin = Linear(dim, dim, config=config)
         self.out_lin = Linear(dim, dim, config=config)

+    def prune_heads(self, heads):
+        attention_head_size = self.dim // self.n_heads
+        if len(heads) == 0:
+            return
+        mask = torch.ones(self.n_heads, attention_head_size)
+        for head in heads:
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+        # Prune linear layers
+        self.q_lin = prune_linear_layer(self.q_lin, index)
+        self.k_lin = prune_linear_layer(self.k_lin, index)
+        self.v_lin = prune_linear_layer(self.v_lin, index)
+        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
+        # Update hyper params
+        self.n_heads = self.n_heads - len(heads)
+        self.dim = attention_head_size * self.n_heads
+
     def forward(self, input, mask, kv=None, cache=None, head_mask=None):
         """
         Self-attention (if kv is None) or attention over source sentence (provided by kv).
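prune_linear_layer is pulled in from .model_utils in the import hunk above, but its body is not part of this commit. A minimal sketch of what such a helper plausibly does, written against a hypothetical name so it is not mistaken for the shipped implementation:

    import torch
    import torch.nn as nn

    def prune_linear_layer_sketch(layer, index, dim=0):
        """Return a smaller nn.Linear keeping only the rows (dim=0) or
        columns (dim=1) of `layer` selected by `index` (a LongTensor)."""
        W = layer.weight.index_select(dim, index).clone().detach()
        b = None
        if layer.bias is not None:
            # The bias only shrinks when output features (dim 0) are pruned.
            b = layer.bias.clone().detach() if dim == 1 else layer.bias[index].clone().detach()
        new_size = list(layer.weight.size())
        new_size[dim] = len(index)
        new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
        with torch.no_grad():
            new_layer.weight.copy_(W.contiguous())
            if b is not None:
                new_layer.bias.copy_(b.contiguous())
        return new_layer

    # e.g. keep output features 0, 2 and 3 of a 4x8 projection:
    pruned = prune_linear_layer_sketch(nn.Linear(8, 4), torch.tensor([0, 2, 3]))
    print(pruned.weight.shape)   # torch.Size([3, 8])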
@@ -286,9 +286,9 @@ class MultiHeadAttention(nn.Module):
             klen = qlen if cache is None else cache['slen'] + qlen
         else:
             klen = kv.size(1)
-        assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
+        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
         n_heads = self.n_heads
-        dim_per_head = dim // n_heads
+        dim_per_head = self.dim // n_heads
         mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

         def shape(x):
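A likely reading of these two changes, not stated in the commit message: once prune_heads above has run, self.dim shrinks to head_size * remaining_heads while the hidden states entering forward keep the original model width, so the strict equality assert no longer holds and dim_per_head must be derived from self.dim. A toy shape check with hypothetical numbers:

    # Hypothetical sizes: 16 heads of 64 dims each, 4 heads pruned away.
    n_heads_before, head_size, pruned = 16, 64, 4
    model_dim = n_heads_before * head_size     # width of `input`, unchanged: 1024
    n_heads = n_heads_before - pruned          # self.n_heads after pruning: 12
    self_dim = n_heads * head_size             # self.dim after pruning: 768
    dim_per_head = self_dim // n_heads         # 64, the correct per-head size
    assert dim_per_head == head_size and model_dim != self_dim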
@@ -335,7 +335,7 @@ class MultiHeadAttention(nn.Module):
         outputs = (self.out_lin(context),)
         if self.output_attentions:
-            outputs = outputs + (weights)
+            outputs = outputs + (weights,)
         return outputs
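The one-character fix here is a plain Python pitfall: (weights) is just weights in redundant parentheses, so the old line tried to concatenate a tuple with a tensor, whereas (weights,) is a one-element tuple. A minimal reproduction with stand-in values:

    outputs = ("context",)      # stands in for (self.out_lin(context),)
    weights = "attn_weights"    # stands in for the attention probabilities

    assert outputs + (weights,) == ("context", "attn_weights")
    # outputs + (weights) would evaluate to tuple + str and raise TypeError.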
@@ -497,6 +497,14 @@ class XLMModel(XLMPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))

+    def _prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+            See base class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.attentions[layer].prune_heads(heads)
+
     def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, head_mask=None):  # src_enc=None, src_len=None,
         """
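For illustration, the shape of the argument the new hook expects and what it does per layer. This sketch uses a stand-in for self.attentions so it runs without building an XLMModel; on a real model one would call model._prune_heads(heads_to_prune):

    # Argument shape: {layer_num: [head indices to drop in that layer]}
    heads_to_prune = {0: [0, 1], 3: [2]}

    class _FakeAttention:                       # stand-in so the sketch is self-contained
        def prune_heads(self, heads):
            print("pruning heads", heads)

    attentions = {0: _FakeAttention(), 3: _FakeAttention()}
    for layer, heads in heads_to_prune.items():
        attentions[layer].prune_heads(heads)    # prints: pruning heads [0, 1] / [2]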
@@ -508,7 +516,7 @@ class XLMModel(XLMPreTrainedModel):
             `token_type_ids` LongTensor (bs, slen) same as `langs` used for compatibility
         """
         if lengths is None:
-            lengths = (input_ids != self.pad_index).float().sum(dim=1)
+            lengths = (input_ids != self.pad_index).sum(dim=1).long()
         # mask = input_ids != self.pad_index
         # check inputs
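The replacement keeps the derived lengths as an integer tensor rather than a float one, which is what downstream indexing and comparisons expect. A small check with a hypothetical pad_index of 2:

    import torch

    pad_index = 2                                   # hypothetical padding id
    input_ids = torch.tensor([[5, 7, 9, 2, 2],
                              [4, 2, 2, 2, 2]])
    lengths = (input_ids != pad_index).sum(dim=1).long()
    print(lengths)         # tensor([3, 1])
    print(lengths.dtype)   # torch.int64 (the old .float() variant gave torch.float32)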
pytorch_pretrained_bert/tests/model_tests_commons.py
@@ -68,6 +68,8 @@ def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict
         attentions = outputs[-1]
         hidden_states = outputs[-2]

+        # Remove Nan
+
         tester.parent.assertIsNotNone(multihead_outputs)
         tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
         tester.parent.assertAlmostEqual(
@@ -298,7 +300,11 @@ class GPTModelTester(object):
                                           mc_labels, lm_labels, mc_token_ids):
         model = self.base_model_class(config)
         model.eval()
         outputs = model(input_ids, position_ids, token_type_ids)
+        outputs = model(input_ids, position_ids)
+        outputs = model(input_ids)
         hidden_state = outputs[0]
         self.parent.assertListEqual(
             list(hidden_state.size()),
pytorch_pretrained_bert/tests/modeling_bert_test.py
@@ -126,6 +126,8 @@ class BertModelTest(unittest.TestCase):
             model = BertModel(config=config)
             model.eval()
             sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
+            sequence_output, pooled_output = model(input_ids, token_type_ids)
+            sequence_output, pooled_output = model(input_ids)
             result = {
                 "sequence_output": sequence_output,
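These added calls (and the matching ones in the GPT and XLM testers) exercise the optional-argument paths: the model is invoked with input_mask and then token_type_ids omitted. A sketch of the default handling such calls typically hit inside forward, stated as an assumption rather than copied from this commit:

    import torch

    def fill_defaults(input_ids, token_type_ids=None, attention_mask=None):
        # Common default handling: all-ones attention mask, all-zeros segment ids.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        return token_type_ids, attention_mask

    tt, am = fill_defaults(torch.tensor([[5, 7, 9]]))
    assert tt.tolist() == [[0, 0, 0]] and am.tolist() == [[1, 1, 1]]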
pytorch_pretrained_bert/tests/modeling_xlm_test.py
@@ -96,7 +96,7 @@ class XLMModelTest(unittest.TestCase):
             input_lengths = None
             if self.use_input_lengths:
-                input_lengths = ids_tensor([self.batch_size], vocab_size=self.seq_length - 1)
+                input_lengths = ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2  # small variation of seq_length
             token_type_ids = None
             if self.use_token_type_ids:
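ids_tensor is the test helper that draws random integers in [0, vocab_size), so the new expression yields lengths of either seq_length - 2 or seq_length - 1 instead of anything from 0 up to seq_length - 2. A stand-in check with hypothetical tester values:

    import torch

    seq_length, batch_size = 7, 3               # hypothetical tester values
    draw = torch.randint(0, 2, (batch_size,))   # plays the role of ids_tensor([batch_size], vocab_size=2)
    input_lengths = draw + seq_length - 2       # each value is seq_length - 2 or seq_length - 1
    assert ((input_lengths == seq_length - 2) | (input_lengths == seq_length - 1)).all()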
@@ -139,6 +139,8 @@ class XLMModelTest(unittest.TestCase):
             model = XLMModel(config=config)
             model.eval()
             outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
+            outputs = model(input_ids, langs=token_type_ids)
+            outputs = model(input_ids)
             sequence_output = outputs[0]
             result = {
                 "sequence_output": sequence_output,
@@ -232,7 +234,7 @@ class XLMModelTest(unittest.TestCase):
         def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths,
                                          sequence_labels, token_labels, choice_labels):
-            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_lengths}
+            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
             create_and_check_commons(self, config, inputs_dict)

         def test_default(self):
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
@@ -140,7 +140,26 @@ class XLNetModelTest(unittest.TestCase):
             random.seed(self.seed)
             torch.manual_seed(self.seed)

-        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
-                                      target_mapping, inp_q, segment_ids, lm_labels):
+        def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
+                                              target_mapping, inp_q, segment_ids, lm_labels):
+            model = XLNetModel(config)
+            model.eval()
+            _, _ = model(input_ids_1, token_type_ids=segment_ids)
+            outputs, mems_1 = model(input_ids_1)
+
+            result = {
+                "mems_1": mems_1,
+                "outputs": outputs,
+            }
+
+            self.parent.assertListEqual(
+                list(result["outputs"].size()),
+                [self.batch_size, self.seq_length, self.hidden_size])
+            self.parent.assertListEqual(
+                list(list(mem.size()) for mem in result["mems_1"]),
+                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
+
+        def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
+                                           target_mapping, inp_q, segment_ids, lm_labels):
             model = XLNetLMHeadModel(config)
             model.eval()
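The two assertions pin down the shape conventions this test relies on: the hidden-state output is batch-major (batch_size, seq_length, hidden_size), while each returned memory is time-major (seq_length, batch_size, hidden_size), one entry per layer. A worked example with hypothetical tester values:

    batch_size, seq_length, hidden_size, num_hidden_layers = 2, 5, 8, 3   # hypothetical
    expected_output_shape = [batch_size, seq_length, hidden_size]         # [2, 5, 8]
    expected_mem_shapes = [[seq_length, batch_size, hidden_size]] * num_hidden_layers
    assert expected_mem_shapes == [[5, 2, 8]] * 3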
@@ -150,7 +169,7 @@ class XLNetModelTest(unittest.TestCase):
             logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q)

-            outputs = {
+            result = {
                 "loss_1": loss_1,
                 "mems_1": mems_1,
                 "all_logits_1": all_logits_1,
@@ -158,9 +177,7 @@ class XLNetModelTest(unittest.TestCase):
                 "mems_2": mems_2,
                 "all_logits_2": all_logits_2,
             }
-            return outputs
-
-        def check_transfo_xl_lm_head_output(self, result):
+
             self.parent.assertListEqual(
                 list(result["loss_1"].size()),
                 [])
@@ -203,8 +220,11 @@ class XLNetModelTest(unittest.TestCase):
     def run_tester(self, tester):
         tester.set_seed()
         config_and_inputs = tester.prepare_config_and_inputs()
-        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
-        tester.check_transfo_xl_lm_head_output(output_result)
+        tester.create_and_check_xlnet_base_model(*config_and_inputs)
+
+        tester.set_seed()
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_xlnet_lm_head(*config_and_inputs)

         tester.set_seed()
         config_and_inputs = tester.prepare_config_and_inputs()
pytorch_pretrained_bert/tokenization_xlm.py
@@ -304,7 +304,6 @@ class XLMTokenizer(object):
         index = 0
         with open(merge_file, "w", encoding="utf-8") as writer:
-            writer.write(u'#version: 0.2\n')
             for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                 if index != token_index:
                     logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
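With the '#version: 0.2' header line gone, the saved merges file contains only the merge pairs in rank order. A toy sketch of what the loop then writes; the ' '.join(...) line is an assumption about the rest of the method, which is cut off in this excerpt:

    bpe_ranks = {('h', 'e'): 0, ('he', 'llo'): 1}   # toy merge table: pair -> rank
    with open("merges_example.txt", "w", encoding="utf-8") as writer:
        for bpe_tokens, token_index in sorted(bpe_ranks.items(), key=lambda kv: kv[1]):
            writer.write(' '.join(bpe_tokens) + '\n')
    # merges_example.txt now reads:
    #   h e
    #   he llo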