Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d5610b53
Unverified
Commit
d5610b53
authored
Jul 27, 2022
by
Yanming Wang
Committed by
GitHub
Jul 27, 2022
Browse files
[XLA] Improve t5 model performance (#18288)
parent
e318cda9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
6 deletions
+2
-6
src/transformers/models/longt5/modeling_longt5.py
src/transformers/models/longt5/modeling_longt5.py
+1
-3
src/transformers/models/t5/modeling_t5.py
src/transformers/models/t5/modeling_t5.py
+1
-3
No files found.
src/transformers/models/longt5/modeling_longt5.py
View file @
d5610b53
...
...
@@ -1331,8 +1331,6 @@ class LongT5PreTrainedModel(PreTrainedModel):
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids
.
masked_fill_
(
shifted_input_ids
==
-
100
,
pad_token_id
)
assert
torch
.
all
(
shifted_input_ids
>=
0
).
item
(),
"Verify that `shifted_input_ids` has only positive values"
return
shifted_input_ids
...
...
@@ -1414,7 +1412,7 @@ class LongT5Stack(LongT5PreTrainedModel):
assert
self
.
is_decoder
,
f
"`use_cache` can only be set to `True` if
{
self
}
is used as a decoder"
if
attention_mask
is
None
:
attention_mask
=
torch
.
ones
(
batch_size
,
mask_seq_length
).
to
(
inputs_embeds
.
device
)
attention_mask
=
torch
.
ones
(
batch_size
,
mask_seq_length
,
device
=
inputs_embeds
.
device
)
if
self
.
is_decoder
and
encoder_attention_mask
is
None
and
encoder_hidden_states
is
not
None
:
encoder_seq_length
=
encoder_hidden_states
.
shape
[
1
]
encoder_attention_mask
=
torch
.
ones
(
...
...
src/transformers/models/t5/modeling_t5.py
View file @
d5610b53
...
...
@@ -827,8 +827,6 @@ class T5PreTrainedModel(PreTrainedModel):
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids
.
masked_fill_
(
shifted_input_ids
==
-
100
,
pad_token_id
)
assert
torch
.
all
(
shifted_input_ids
>=
0
).
item
(),
"Verify that `shifted_input_ids` has only positive values"
return
shifted_input_ids
...
...
@@ -944,7 +942,7 @@ class T5Stack(T5PreTrainedModel):
assert
self
.
is_decoder
,
f
"`use_cache` can only be set to `True` if
{
self
}
is used as a decoder"
if
attention_mask
is
None
:
attention_mask
=
torch
.
ones
(
batch_size
,
mask_seq_length
).
to
(
inputs_embeds
.
device
)
attention_mask
=
torch
.
ones
(
batch_size
,
mask_seq_length
,
device
=
inputs_embeds
.
device
)
if
self
.
is_decoder
and
encoder_attention_mask
is
None
and
encoder_hidden_states
is
not
None
:
encoder_seq_length
=
encoder_hidden_states
.
shape
[
1
]
encoder_attention_mask
=
torch
.
ones
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment