chenpangpang / transformers · Commits

Unverified commit 026a5d08, authored May 18, 2020 by Patrick von Platen, committed by GitHub on May 18, 2020
Parent commit: fa6113f9

[T5 fp16] Fix fp16 in T5 (#4436)

* fix fp16 in t5
* make style
* refactor invert_attention_mask fn
* fix typo
Showing 3 changed files with 36 additions and 3 deletions (+36 / -3):

  src/transformers/modeling_t5.py     +9  -2
  src/transformers/modeling_utils.py  +12 -1
  tests/test_modeling_t5.py           +15 -0
src/transformers/modeling_t5.py
@@ -149,8 +149,12 @@ class T5LayerNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, x):
-        variance = x.pow(2).mean(-1, keepdim=True)
+        # layer norm should always be calculated in float32
+        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
         x = x / torch.sqrt(variance + self.variance_epsilon)
+
+        if self.weight.dtype == torch.float16:
+            x = x.to(torch.float16)
         return self.weight * x
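For context, a minimal sketch of why the hunk above computes the variance in float32 (this snippet is illustrative only and not part of the commit): float16 overflows at roughly 65504, so squaring moderately large activations already produces inf and the RMS normalization collapses to inf/NaN unless the reduction is done in float32 first.

import torch

# Illustration only: squaring values above ~256 overflows float16,
# so the variance used by T5's layer norm becomes inf in half precision.
x = torch.full((1, 4), 300.0, dtype=torch.float16)
print(x.pow(2).mean(-1, keepdim=True))                    # inf in float16
print(x.to(torch.float32).pow(2).mean(-1, keepdim=True))  # 90000.0 in float32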
@@ -691,7 +695,9 @@ class T5Stack(T5PreTrainedModel):
             attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
         if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device)
+            encoder_attention_mask = torch.ones(
+                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
+            )
 
         # initialize past_key_value_states with `None` if past does not exist
         if past_key_value_states is None:
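As a side note, a small sketch of what the changed `torch.ones(...)` call does differently (an assumption about the intent, not taken from the commit itself): the mask is now allocated directly on the target device as an integer tensor, instead of being created as a default float32 tensor and copied afterwards.

import torch

# Sketch with placeholder shapes; the device stands in for inputs_embeds.device.
batch_size, encoder_seq_length = 2, 7
device = torch.device("cpu")

old_mask = torch.ones(batch_size, encoder_seq_length).to(device)
new_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
print(old_mask.dtype, new_mask.dtype)  # torch.float32 torch.int64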
@@ -733,6 +739,7 @@ class T5Stack(T5PreTrainedModel):
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states, present_key_value_state = layer_outputs[:2]
+
             if i == 0:
                 # We share the position biases between the layers - the first layer store them
                 # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
src/transformers/modeling_utils.py
@@ -128,7 +128,18 @@ class ModuleUtilsMixin:
         # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
         # encoder_extended_attention_mask.transpose(-1, -2))
         encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
+
+        if self.dtype == torch.float16:
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
+        elif self.dtype == torch.float32:
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
+        else:
+            raise ValueError(
+                "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
+                    self.dtype
+                )
+            )
         return encoder_extended_attention_mask
 
     def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
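A quick hedged check (illustrative, not part of the diff) of why the masking constant now depends on the dtype: -1e9 lies outside float16's representable range (about ±65504), so it becomes -inf and propagates NaNs through the attention softmax, while -1e4 is still effectively "minus infinity" for masking yet stays finite in half precision.

import torch

# Illustration only: the additive attention mask must stay within float16's range.
print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)
print(torch.tensor(-1e4, dtype=torch.float16))  # tensor(-10000., dtype=torch.float16)
print(torch.finfo(torch.float16).min)           # -65504.0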
tests/test_modeling_t5.py
@@ -304,6 +304,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
             output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
             self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
 
+        def create_and_check_t5_model_fp16_forward(
+            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
+        ):
+            model = T5Model(config=config)
+            model.to(torch_device)
+            model.half()
+            model.eval()
+            output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0]
+            self.parent.assertFalse(torch.isnan(output).any().item())
+
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
             (
@@ -355,6 +365,11 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
 
+    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+    def test_t5_model_fp16_forward(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_t5_model_fp16_forward(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
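To exercise the fix end to end, here is a standalone sketch along the lines of the new test (assumptions not specified by the commit: a CUDA device is available and the public t5-small checkpoint is used in place of the tiny test config):

import torch
from transformers import T5Model, T5Tokenizer

# Mirrors the NaN check in create_and_check_t5_model_fp16_forward,
# but on real pretrained weights instead of the test configuration.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small").to("cuda").half().eval()

input_ids = tokenizer.encode("translate English to German: Hello there", return_tensors="pt").to("cuda")
with torch.no_grad():
    output = model(input_ids, decoder_input_ids=input_ids)[0]

print(torch.isnan(output).any().item())  # expected False once this fix is applied

Alternatively, the added unit test can be run directly on a GPU machine with something like `pytest tests/test_modeling_t5.py -k fp16_forward`; it is skipped on CPU by the `unittest.skipIf` decorator above.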