chenpangpang / transformers: Commits

Commit b016dd16
authored Dec 09, 2019 by thomwolf

fix tests on python 3.5

parent 169fea68
Changes: 3 changed files with 10 additions and 9 deletions

    transformers/modeling_t5.py                   +1  -1
    transformers/tests/modeling_common_test.py    +8  -7
    transformers/tokenization_t5.py               +1  -1
transformers/modeling_t5.py (view file @ b016dd16)
@@ -338,7 +338,7 @@ class T5Attention(nn.Module):
                 raise ValueError("No position_bias provided and no weights to compute position_bias")
             position_bias = self.compute_bias(qlen, klen)
             if mask is not None:
-                position_bias += mask  # (bs, n_heads, qlen, klen)
+                position_bias = position_bias + mask  # (bs, n_heads, qlen, klen)
 
         scores += position_bias
         weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
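The hunk above replaces an in-place tensor update with an out-of-place one. The commit message only says it fixes tests on Python 3.5, so the exact motivation is not stated on this page; purely as an illustration (the names and shapes below are invented, mirroring the `(bs, n_heads, qlen, klen)` comment in the diff), the two forms differ in whether the existing tensor object is mutated:

```python
import torch

bs, n_heads, qlen, klen = 1, 2, 3, 3
bias = torch.zeros(bs, n_heads, qlen, klen)
mask = torch.full((bs, 1, 1, klen), -10000.0)

# Out-of-place (the new code): `bias + mask` allocates a fresh tensor and rebinds
# the name, so anything else still holding the original tensor is unaffected.
original = bias
bias = bias + mask
print(torch.equal(original, torch.zeros(bs, n_heads, qlen, klen)))  # True

# In-place (the old code): `+=` writes into the existing storage, so every
# reference to that tensor sees the masked values afterwards.
bias2 = torch.zeros(bs, n_heads, qlen, klen)
alias = bias2
bias2 += mask
print(torch.equal(alias, bias2))  # True, alias now holds the masked values
```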
transformers/tests/modeling_common_test.py (view file @ b016dd16)
@@ -138,8 +138,8 @@ class CommonTestCases:
             self.assertListEqual(
                 list(attentions[0].shape[-3:]),
                 [self.model_tester.num_attention_heads,
-                  self.model_tester.seq_length,
-                  self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+                  self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
+                  self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length])
             out_len = len(outputs)
 
             if self.is_encoder_decoder:
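For orientation only (this snippet is not part of the test file): attention weights come out with shape `(batch, num_heads, query_len, key_len)`, so for an encoder block both trailing dimensions equal the encoder sequence length, which is what the updated assertion checks. A minimal sketch with invented sizes:

```python
import torch
import torch.nn.functional as F

batch, num_heads, encoder_seq_length, head_dim = 2, 4, 7, 8
q = torch.randn(batch, num_heads, encoder_seq_length, head_dim)
k = torch.randn(batch, num_heads, encoder_seq_length, head_dim)

# Toy self-attention weights: softmax over the key dimension.
weights = F.softmax(q @ k.transpose(-1, -2) / head_dim ** 0.5, dim=-1)

# Same shape check as the assertion above, applied to the last three dimensions.
assert list(weights.shape[-3:]) == [num_heads, encoder_seq_length, encoder_seq_length]
```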
@@ -151,8 +151,8 @@ class CommonTestCases:
             self.assertListEqual(
                 list(decoder_attentions[0].shape[-3:]),
                 [self.model_tester.num_attention_heads,
-                  self.model_tester.seq_length,
-                  self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+                  self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length,
+                  self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length])
 
             # Check attention is always last and order is fine
             config.output_attentions = True
@@ -169,8 +169,8 @@ class CommonTestCases:
             self.assertListEqual(
                 list(self_attentions[0].shape[-3:]),
                 [self.model_tester.num_attention_heads,
-                  self.model_tester.seq_length,
-                  self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+                  self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
+                  self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length])
 
         def test_torchscript(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -440,7 +440,8 @@ class CommonTestCases:
             self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
-                [self.model_tester.seq_length, self.model_tester.hidden_size])
+                [self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
+                 self.model_tester.hidden_size])
 
         def test_resize_tokens_embeddings(self):
             original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
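All four hunks in this file use the same `value if hasattr(obj, 'name') else fallback` pattern so that testers for encoder-decoder models (which define separate encoder and decoder lengths) and single-stack testers (which only define `seq_length`) can share one assertion. As a side note, not part of this commit, the same fallback can be written with `getattr`; the tester classes below are invented for illustration:

```python
class EncoderDecoderTester:
    encoder_seq_length = 7
    decoder_seq_length = 9
    seq_length = 7

class SingleStackTester:
    seq_length = 13

def expected_len(tester):
    # Equivalent to:
    #   tester.encoder_seq_length if hasattr(tester, 'encoder_seq_length') else tester.seq_length
    return getattr(tester, "encoder_seq_length", tester.seq_length)

print(expected_len(EncoderDecoderTester()))  # 7
print(expected_len(SingleStackTester()))     # 13
```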
transformers/tokenization_t5.py (view file @ b016dd16)
@@ -134,7 +134,7 @@ class T5Tokenizer(PreTrainedTokenizer):
         """ Converts a token (str/unicode) in an id using the vocab. """
         if token.startswith(u"<extra_id_"):
             l = re.match(r'<extra_id_(\d+)>', token)
-            num = int(l[1])
+            num = int(l.group(1))
             return self.vocab_size - num - 1
         return self.sp_model.piece_to_id(token)
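The one-line change above is the most direct Python 3.5 fix in the commit: subscripting a regular-expression match object (`l[1]`) only works on Python 3.6 and later, where match objects gained `__getitem__`, while `l.group(1)` also works on 3.5. A standalone illustration:

```python
import re

token = "<extra_id_12>"
m = re.match(r"<extra_id_(\d+)>", token)

# Works on Python 3.5 and later: explicit group access.
num = int(m.group(1))
print(num)  # 12

# m[1] would give the same result, but only on Python 3.6+,
# where re match objects became subscriptable.
```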