chenpangpang / transformers / Commits

Commit 86caeb76 (unverified)
Authored Feb 19, 2021 by Julien Plu; committed by GitHub on Feb 19, 2021

Fix XLA and AMP (#10262)
Parent: 3d72d47f

Showing 2 changed files with 28 additions and 34 deletions (+28 / -34):

    src/transformers/models/t5/modeling_tf_t5.py   +28  -18
    tests/test_modeling_tf_t5.py                    +0  -16
src/transformers/models/t5/modeling_tf_t5.py
@@ -169,13 +169,17 @@ class TFT5Attention(tf.keras.layers.Layer):
         self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
+        self.pruned_heads = set()
+
+    def build(self, input_shape):
         if self.has_relative_attention_bias:
-            self.relative_attention_bias = tf.keras.layers.Embedding(
-                self.relative_attention_num_buckets,
-                self.n_heads,
-                name="relative_attention_bias",
-            )
+            with tf.name_scope("relative_attention_bias"):
+                self.relative_attention_bias = self.add_weight(
+                    name="embeddings",
+                    shape=[self.relative_attention_num_buckets, self.n_heads],
+                )
-        self.pruned_heads = set()
+
+        return super().build(input_shape)
 
     def prune_heads(self, heads):
         raise NotImplementedError
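Note on the hunk above: the relative-attention bias table moves out of __init__ (where it was a tf.keras.layers.Embedding sub-layer) and into build(), where it is created directly with add_weight under a tf.name_scope that preserves the checkpoint variable name. Below is a minimal sketch of that pattern, not part of the commit; the ToyRelativeBias class and its sizes are invented for illustration.

import tensorflow as tf

class ToyRelativeBias(tf.keras.layers.Layer):
    # Hypothetical stand-in for TFT5Attention, reduced to the weight-creation pattern.
    def __init__(self, num_buckets=32, n_heads=8, **kwargs):
        super().__init__(**kwargs)
        self.num_buckets = num_buckets
        self.n_heads = n_heads

    def build(self, input_shape):
        # The name_scope keeps the variable addressable as
        # "relative_attention_bias/embeddings", matching existing checkpoints.
        with tf.name_scope("relative_attention_bias"):
            self.relative_attention_bias = self.add_weight(
                name="embeddings", shape=[self.num_buckets, self.n_heads]
            )
        return super().build(input_shape)

    def call(self, bucket_ids):
        # Plain-weight equivalent of an Embedding lookup.
        return tf.gather(self.relative_attention_bias, bucket_ids)

layer = ToyRelativeBias()
print(layer(tf.constant([[0, 1], [2, 3]])).shape)  # (2, 2, 8)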
@@ -206,18 +210,20 @@ class TFT5Attention(tf.keras.layers.Layer):
         # n = -relative_position
         if bidirectional:
             num_buckets //= 2
-            relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets
+            relative_buckets += (
+                tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets
+            )
             relative_position = tf.math.abs(relative_position)
         else:
             relative_position = -tf.math.minimum(relative_position, 0)
         # now n is in the range [0, inf)
         max_exact = num_buckets // 2
         is_small = tf.math.less(relative_position, max_exact)
-        relative_position_if_large = max_exact + tf.dtypes.cast(
-            tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact)
+        relative_position_if_large = max_exact + tf.cast(
+            tf.math.log(relative_position / max_exact)
             / math.log(max_distance / max_exact)
             * (num_buckets - max_exact),
-            tf.int32,
+            dtype=relative_position.dtype,
         )
         relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
         relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
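The rewrites above replace hardcoded tf.int32 / tf.float32 casts with dtype=relative_position.dtype, so the bucket arithmetic keeps whatever dtype it receives. A self-contained sketch of that idea, with toy values and a hypothetical function name:

import math

import tensorflow as tf

def toy_bucket(relative_position, num_buckets=32, max_distance=128):
    # Dtype-preserving variant: the cast target follows the input tensor
    # instead of being pinned to tf.int32.
    max_exact = num_buckets // 2
    is_small = tf.math.less(relative_position, max_exact)
    if_large = max_exact + tf.cast(
        tf.math.log(relative_position / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact),
        dtype=relative_position.dtype,
    )
    if_large = tf.math.minimum(if_large, num_buckets - 1)
    return tf.where(is_small, relative_position, if_large)

# The result dtype follows the input dtype (int64 here).
print(toy_bucket(tf.constant([1, 20, 100], dtype=tf.int64)).dtype)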
@@ -233,7 +239,9 @@ class TFT5Attention(tf.keras.layers.Layer):
             bidirectional=(not self.is_decoder),
             num_buckets=self.relative_attention_num_buckets,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = tf.gather(
+            self.relative_attention_bias, relative_position_bucket
+        )  # shape (query_length, key_length, num_heads)
         values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)  # shape (1, num_heads, query_length, key_length)
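With the bias now a plain weight, the lookup becomes an explicit tf.gather followed by the same reshaping as before. A toy reproduction of the shapes involved (sizes invented):

import tensorflow as tf

weight = tf.random.normal([32, 8])          # (num_buckets, num_heads)
buckets = tf.zeros([5, 7], dtype=tf.int32)  # (query_length, key_length)
values = tf.gather(weight, buckets)         # (5, 7, 8)
# Transpose + expand_dims give the (1, num_heads, query_length, key_length)
# bias that is added to the attention scores.
bias = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)
print(bias.shape)  # (1, 8, 5, 7)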
@@ -326,7 +334,7 @@ class TFT5Attention(tf.keras.layers.Layer):
         if position_bias is None:
             if not self.has_relative_attention_bias:
-                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32)
+                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length))
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)
@@ -336,6 +344,7 @@ class TFT5Attention(tf.keras.layers.Layer):
                 position_bias = position_bias[:, :, -seq_length:, :]
 
             if mask is not None:
+                position_bias = tf.cast(position_bias, dtype=mask.dtype)
                 position_bias = position_bias + mask  # (batch_size, n_heads, query_length, key_length)
 
         scores += position_bias
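These two hunks drop the explicit dtype=tf.float32 on the zero bias and instead cast position_bias to the mask's dtype right before the addition, which is what keeps the op valid under mixed precision (AMP), where the mask may be float16. A small sketch of the failure mode being avoided, with made-up shapes:

import tensorflow as tf

position_bias = tf.zeros((1, 8, 4, 4))  # defaults to float32
mask = tf.constant(-1e4, shape=(1, 1, 1, 4), dtype=tf.float16)
# Without the cast, float32 + float16 raises a dtype error; casting the
# bias to mask.dtype first makes the addition well-typed.
position_bias = tf.cast(position_bias, dtype=mask.dtype)
print((position_bias + mask).dtype)  # float16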
@@ -662,7 +671,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
+        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype)
         num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
         if num_dims_attention_mask == 3:
             extended_attention_mask = inputs["attention_mask"][:, None, :, :]
@@ -676,7 +685,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                     tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                     seq_ids[None, :, None],
                 )
-                causal_mask = tf.cast(causal_mask, dtype=tf.float32)
+                causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
                 extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
                 if inputs["past_key_values"][0] is not None:
                     extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
@@ -700,7 +709,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             # If a 2D ou 3D attention mask is provided for the cross-attention
             # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
             # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32)
+            inputs["encoder_attention_mask"] = tf.cast(
+                inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
+            )
             num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
             if num_dims_encoder_attention_mask == 3:
                 encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
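The three mask hunks above share one technique: every tf.cast that used to force tf.float32 now derives its target dtype from a tensor that already has the right one (inputs_embeds, the attention mask, or the extended mask). Under the Keras mixed-precision policy those tensors are float16, which is the AMP case this commit fixes. A hedged sketch of that setting, using a toy embedding layer rather than the T5 code:

import tensorflow as tf

# Under the mixed_float16 policy, layer outputs are float16 while variables
# stay float32, so a mask pinned to float32 would no longer match the
# activations it is combined with.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
embed = tf.keras.layers.Embedding(100, 16)
inputs_embeds = embed(tf.constant([[1, 2, 3]]))
attention_mask = tf.cast(tf.constant([[1, 1, 0]]), dtype=inputs_embeds.dtype)
print(inputs_embeds.dtype, attention_mask.dtype)  # float16 float16
tf.keras.mixed_precision.set_global_policy("float32")  # restore the default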
@@ -868,8 +879,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
         assert (
             decoder_start_token_id is not None
         ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
-        shifted_input_ids = tf.cast(input_ids, tf.int32)
-        shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+        shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
         start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
         shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
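Dropping the tf.cast(input_ids, tf.int32) means the shift no longer forces a dtype on its input; tf.roll operates on the ids as given. A toy demonstration of the shift-right trick itself (made-up token ids):

import tensorflow as tf

input_ids = tf.constant([[10, 11, 12, 13]])
decoder_start_token_id = 0  # illustrative value

# Roll every token one position to the right, then overwrite the first
# column with the decoder start token.
shifted = tf.roll(input_ids, 1, axis=-1)  # [[13, 10, 11, 12]]
start_tokens = tf.fill((tf.shape(input_ids)[0], 1), decoder_start_token_id)
shifted = tf.concat([start_tokens, shifted[:, 1:]], -1)
print(shifted.numpy())  # [[ 0 10 11 12]]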
@@ -880,7 +890,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
         )
 
         # "Verify that `labels` has only positive values and -100"
-        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
 
         # Make sure the assertion op is called by wrapping the result in an identity no-op
         with tf.control_dependencies([assert_gte0]):
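The comparison scalar becomes tf.constant(0) rather than a tf.cast at graph-build time; the surrounding control_dependencies pattern is what guarantees the assertion op actually executes in graph mode. A minimal sketch of that pattern, with a hypothetical helper name:

import tensorflow as tf

@tf.function
def check_nonnegative(ids):
    # In graph mode the assertion is an op; making downstream work depend on
    # it ensures it is not pruned from the graph.
    assertion = tf.debugging.assert_greater_equal(ids, tf.constant(0))
    with tf.control_dependencies([assertion]):
        return tf.identity(ids)

print(check_nonnegative(tf.constant([[0, 5, 3]])).numpy())  # [[0 5 3]]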
tests/test_modeling_tf_t5.py
@@ -305,14 +305,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make T5 float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make T5 XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         model = TFT5Model.from_pretrained("t5-small")
@@ -442,14 +434,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_train_pipeline_custom_model(self):
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make T5 float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make T5 XLA compliant
-        pass
-
     @require_tf
     @require_sentencepiece
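With the fixes above in place, both test classes stop overriding test_mixed_precision and test_xla_mode with empty stubs, so the versions inherited from TFModelTesterMixin run again for T5. A hedged sketch of the kind of check an XLA test performs, using a toy model in place of TFT5Model (jit_compile requires TF 2.5 or newer; earlier releases spelled it experimental_compile):

import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Embedding(100, 8), tf.keras.layers.Dense(8)]
)

@tf.function(jit_compile=True)
def forward(ids):
    # Compiling the forward pass with XLA fails if any op inside is not
    # XLA-compatible, which is what such a test is meant to catch.
    return model(ids)

print(forward(tf.constant([[1, 2, 3]])).shape)  # (1, 3, 8)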