Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
43891be1
Unverified
Commit
43891be1
authored
May 19, 2021
by
Patrick von Platen
Committed by
GitHub
May 19, 2021
Browse files
[T5 failing CI] Fix generate test (#11770)
* fix_torch_device_generate_test * remove @
parent
680d181c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
8 deletions
+15
-8
tests/test_generation_utils.py
tests/test_generation_utils.py
+7
-3
tests/test_modeling_t5.py
tests/test_modeling_t5.py
+8
-5
No files found.
tests/test_generation_utils.py
View file @
43891be1
...
@@ -1084,9 +1084,13 @@ class GenerationTesterMixin:
...
@@ -1084,9 +1084,13 @@ class GenerationTesterMixin:
continue
continue
head_masking
=
{
head_masking
=
{
"head_mask"
:
torch
.
zeros
(
config
.
encoder_layers
,
config
.
encoder_attention_heads
),
"head_mask"
:
torch
.
zeros
(
config
.
encoder_layers
,
config
.
encoder_attention_heads
,
device
=
torch_device
),
"decoder_head_mask"
:
torch
.
zeros
(
config
.
decoder_layers
,
config
.
decoder_attention_heads
),
"decoder_head_mask"
:
torch
.
zeros
(
"cross_attn_head_mask"
:
torch
.
zeros
(
config
.
decoder_layers
,
config
.
decoder_attention_heads
),
config
.
decoder_layers
,
config
.
decoder_attention_heads
,
device
=
torch_device
),
"cross_attn_head_mask"
:
torch
.
zeros
(
config
.
decoder_layers
,
config
.
decoder_attention_heads
,
device
=
torch_device
),
}
}
signature
=
inspect
.
signature
(
model
.
forward
)
signature
=
inspect
.
signature
(
model
.
forward
)
...
...
tests/test_modeling_t5.py
View file @
43891be1
...
@@ -605,19 +605,22 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
...
@@ -605,19 +605,22 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config
=
config_and_inputs
[
0
]
config
=
config_and_inputs
[
0
]
max_length
=
config_and_inputs
[
1
].
shape
[
-
1
]
+
3
max_length
=
config_and_inputs
[
1
].
shape
[
-
1
]
+
3
model
=
T5ForConditionalGeneration
(
config
)
model
=
T5ForConditionalGeneration
(
config
).
eval
()
model
.
to
(
torch_device
)
head_masking
=
{
head_masking
=
{
"head_mask"
:
torch
.
zeros
(
config
.
num_layers
,
config
.
num_heads
),
"head_mask"
:
torch
.
zeros
(
config
.
num_layers
,
config
.
num_heads
,
device
=
torch_device
),
"decoder_head_mask"
:
torch
.
zeros
(
config
.
num_decoder_layers
,
config
.
num_heads
),
"decoder_head_mask"
:
torch
.
zeros
(
config
.
num_decoder_layers
,
config
.
num_heads
,
device
=
torch_device
),
"cross_attn_head_mask"
:
torch
.
zeros
(
config
.
num_decoder_layers
,
config
.
num_heads
),
"cross_attn_head_mask"
:
torch
.
zeros
(
config
.
num_decoder_layers
,
config
.
num_heads
,
device
=
torch_device
),
}
}
for
attn_name
,
(
name
,
mask
)
in
zip
(
attention_names
,
head_masking
.
items
()):
for
attn_name
,
(
name
,
mask
)
in
zip
(
attention_names
,
head_masking
.
items
()):
head_masks
=
{
name
:
mask
}
head_masks
=
{
name
:
mask
}
# Explicitly pass decoder_head_mask as it is required from T5 model when head_mask specified
# Explicitly pass decoder_head_mask as it is required from T5 model when head_mask specified
if
name
==
"head_mask"
:
if
name
==
"head_mask"
:
head_masks
[
"decoder_head_mask"
]
=
torch
.
ones
(
config
.
num_decoder_layers
,
config
.
num_heads
)
head_masks
[
"decoder_head_mask"
]
=
torch
.
ones
(
config
.
num_decoder_layers
,
config
.
num_heads
,
device
=
torch_device
)
out
=
model
.
generate
(
out
=
model
.
generate
(
config_and_inputs
[
1
],
config_and_inputs
[
1
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment