Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ccd1923f
Unverified
Commit
ccd1923f
authored
Jan 12, 2021
by
Suraj Patil
Committed by
GitHub
Jan 12, 2021
Browse files
[T5] enable T5 fp16 (#9487)
* fix t5 fp16
parent
2aa9c2f2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
0 deletions
+12
-0
src/transformers/models/t5/modeling_t5.py
src/transformers/models/t5/modeling_t5.py
+12
-0
No files found.
src/transformers/models/t5/modeling_t5.py
View file @
ccd1923f
...
...
@@ -640,6 +640,11 @@ class T5Block(nn.Module):
hidden_states
,
present_key_value_state
=
self_attention_outputs
[:
2
]
attention_outputs
=
self_attention_outputs
[
2
:]
# Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training
if
torch
.
isinf
(
hidden_states
).
any
():
clamp_value
=
torch
.
finfo
(
hidden_states
.
dtype
).
max
-
1000
hidden_states
=
torch
.
clamp
(
hidden_states
,
min
=-
clamp_value
,
max
=
clamp_value
)
do_cross_attention
=
self
.
is_decoder
and
encoder_hidden_states
is
not
None
if
do_cross_attention
:
# the actual query length is unknown for cross attention
...
...
@@ -661,6 +666,10 @@ class T5Block(nn.Module):
output_attentions
=
output_attentions
,
)
hidden_states
=
cross_attention_outputs
[
0
]
if
torch
.
isinf
(
hidden_states
).
any
():
clamp_value
=
torch
.
finfo
(
hidden_states
.
dtype
).
max
-
1000
hidden_states
=
torch
.
clamp
(
hidden_states
,
min
=-
clamp_value
,
max
=
clamp_value
)
# Combine self attn and cross attn key value states
if
present_key_value_state
is
not
None
:
present_key_value_state
=
present_key_value_state
+
cross_attention_outputs
[
1
]
...
...
@@ -670,6 +679,9 @@ class T5Block(nn.Module):
# Apply Feed Forward layer
hidden_states
=
self
.
layer
[
-
1
](
hidden_states
)
if
torch
.
isinf
(
hidden_states
).
any
():
clamp_value
=
torch
.
finfo
(
hidden_states
.
dtype
).
max
-
1000
hidden_states
=
torch
.
clamp
(
hidden_states
,
min
=-
clamp_value
,
max
=
clamp_value
)
outputs
=
(
hidden_states
,)
outputs
=
outputs
+
(
present_key_value_state
,)
+
attention_outputs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment