chenpangpang / transformers, commit d5610b53 (unverified)
Authored by Yanming Wang on Jul 27, 2022; committed by GitHub on Jul 27, 2022

[XLA] Improve t5 model performance (#18288)

Parent: e318cda9

Showing 2 changed files with 2 additions and 6 deletions (+2, -6):

src/transformers/models/longt5/modeling_longt5.py: +1, -3
src/transformers/models/t5/modeling_t5.py: +1, -3

src/transformers/models/longt5/modeling_longt5.py

@@ -1331,8 +1331,6 @@ class LongT5PreTrainedModel(PreTrainedModel):
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

-        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
-
         return shifted_input_ids

@@ -1414,7 +1412,7 @@ class LongT5Stack(LongT5PreTrainedModel):
             assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
         if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
             encoder_attention_mask = torch.ones(
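
The only behavioral change in the second hunk swaps `torch.ones(...).to(device)` for `torch.ones(..., device=...)`. A minimal sketch of why that matters, with illustrative shapes and device selection (not the library's code): constructing the mask directly on the target device skips the intermediate CPU allocation and host-to-device copy that the two-step form pays on every forward pass.

    import torch

    # Hypothetical values standing in for batch_size and mask_seq_length.
    batch_size, mask_seq_length = 4, 128
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Before: the tensor is first allocated on the default (CPU) device,
    # then copied to the target device by `.to(...)`.
    mask_before = torch.ones(batch_size, mask_seq_length).to(device)

    # After: the tensor is created directly on the target device, so the
    # intermediate CPU allocation and the transfer are skipped entirely.
    mask_after = torch.ones(batch_size, mask_seq_length, device=device)

    # Both forms produce the same values; only the construction cost differs.
    assert torch.equal(mask_before, mask_after)
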
src/transformers/models/t5/modeling_t5.py

@@ -827,8 +827,6 @@ class T5PreTrainedModel(PreTrainedModel):
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

-        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
-
         return shifted_input_ids

@@ -944,7 +942,7 @@ class T5Stack(T5PreTrainedModel):
             assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
         if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
             encoder_attention_mask = torch.ones(
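
The other change, identical in both files, deletes the `assert torch.all(shifted_input_ids >= 0).item()` check from `_shift_right`. A minimal sketch of the cost that removes, using illustrative data rather than the library's code: `.item()` copies a device scalar back to the host and blocks until the device has produced it, and on a lazy backend such as torch_xla it also cuts the traced graph, forcing a compile-and-execute step on every call.

    import torch

    # Illustrative stand-in for `shifted_input_ids` (shape and vocab size
    # are hypothetical).
    shifted_input_ids = torch.randint(0, 32128, (4, 128))

    # The reduction itself stays on the device as a 0-d boolean tensor.
    all_non_negative = torch.all(shifted_input_ids >= 0)

    # `.item()` transfers that scalar to the host and blocks on it; under
    # torch_xla this additionally triggers graph compilation and execution
    # at this point, which is why the commit drops the assert from the hot
    # path instead of keeping a per-step validity check.
    assert all_non_negative.item(), "Verify that `shifted_input_ids` has only positive values"
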