chenpangpang / transformers · commit 86caeb76 (unverified)

Fix XLA and AMP (#10262)

Authored Feb 19, 2021 by Julien Plu; committed by GitHub on Feb 19, 2021.
Parent: 3d72d47f
Changes: 2 changed files, 28 additions and 34 deletions (+28 −34)

  src/transformers/models/t5/modeling_tf_t5.py   +28 −18
  tests/test_modeling_tf_t5.py                    +0 −16
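The changes below remove hard-coded tf.float32/tf.int32 casts and a lazily-built sub-layer that broke TF T5 under automatic mixed precision (AMP) and XLA compilation. As a rough illustration of the usage this fix targets, here is a sketch that is not part of the commit; it assumes a recent TensorFlow where tf.keras.mixed_precision.set_global_policy and tf.function(jit_compile=True) are available, and a transformers version containing this change:

import tensorflow as tf
from transformers import T5Tokenizer, TFT5Model

# AMP: run the Keras layers in float16 where safe.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5Model.from_pretrained("t5-small")
inputs = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="tf")

# XLA: compile the forward pass; hard-coded dtypes used to break this path.
@tf.function(jit_compile=True)
def forward(input_ids, attention_mask):
    return model(input_ids, attention_mask=attention_mask, decoder_input_ids=input_ids).last_hidden_state

out = forward(inputs["input_ids"], inputs["attention_mask"])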
src/transformers/models/t5/modeling_tf_t5.py
@@ -169,13 +169,17 @@ class TFT5Attention(tf.keras.layers.Layer):
         self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
+        self.pruned_heads = set()
 
+    def build(self, input_shape):
         if self.has_relative_attention_bias:
-            self.relative_attention_bias = tf.keras.layers.Embedding(
-                self.relative_attention_num_buckets,
-                self.n_heads,
-                name="relative_attention_bias",
-            )
-        self.pruned_heads = set()
+            with tf.name_scope("relative_attention_bias"):
+                self.relative_attention_bias = self.add_weight(
+                    name="embeddings",
+                    shape=[self.relative_attention_num_buckets, self.n_heads],
+                )
+
+        return super().build(input_shape)
 
     def prune_heads(self, heads):
         raise NotImplementedError
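The relative-attention bias is no longer a nested tf.keras.layers.Embedding; the weight is created explicitly in build() under a tf.name_scope and looked up with tf.gather later on. A minimal sketch of the same pattern in a standalone layer (the layer name and sizes here are made up for illustration):

import tensorflow as tf

class BiasTable(tf.keras.layers.Layer):
    """Hypothetical layer illustrating the build()/add_weight pattern used above."""

    def __init__(self, num_buckets, n_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_buckets = num_buckets
        self.n_heads = n_heads

    def build(self, input_shape):
        # Create the table once, under an explicit name scope, instead of
        # instantiating a sub-Embedding layer.
        with tf.name_scope("relative_attention_bias"):
            self.relative_attention_bias = self.add_weight(
                name="embeddings",
                shape=[self.num_buckets, self.n_heads],
            )
        return super().build(input_shape)

    def call(self, bucket_ids):
        # A plain gather behaves the same in eager, graph, and compiled modes.
        return tf.gather(self.relative_attention_bias, bucket_ids)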
@@ -206,18 +210,20 @@ class TFT5Attention(tf.keras.layers.Layer):
         # n = -relative_position
         if bidirectional:
             num_buckets //= 2
-            relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets
+            relative_buckets += (
+                tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets
+            )
             relative_position = tf.math.abs(relative_position)
         else:
             relative_position = -tf.math.minimum(relative_position, 0)
         # now n is in the range [0, inf)
         max_exact = num_buckets // 2
         is_small = tf.math.less(relative_position, max_exact)
-        relative_position_if_large = max_exact + tf.dtypes.cast(
-            tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact)
+        relative_position_if_large = max_exact + tf.cast(
+            tf.math.log(relative_position / max_exact)
             / math.log(max_distance / max_exact)
             * (num_buckets - max_exact),
-            tf.int32,
+            dtype=relative_position.dtype,
         )
         relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
         relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
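The bucketing math now casts back to relative_position.dtype instead of hard-coding tf.int32 and tf.float32, so the result follows whatever integer dtype the caller provides. A small eager check of that behaviour, a sketch that reuses the same ops rather than the library function itself:

import math
import tensorflow as tf

def bucket_large_positions(relative_position, num_buckets=32, max_distance=128):
    # Sketch of the dtype-preserving branch above: the output dtype follows
    # `relative_position` instead of a hard-coded tf.int32.
    max_exact = num_buckets // 2
    return max_exact + tf.cast(
        tf.math.log(relative_position / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact),
        dtype=relative_position.dtype,
    )

print(bucket_large_positions(tf.constant([20, 40, 80], dtype=tf.int32)).dtype)  # int32
print(bucket_large_positions(tf.constant([20, 40, 80], dtype=tf.int64)).dtype)  # int64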
@@ -233,7 +239,9 @@ class TFT5Attention(tf.keras.layers.Layer):
             bidirectional=(not self.is_decoder),
             num_buckets=self.relative_attention_num_buckets,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = tf.gather(
+            self.relative_attention_bias, relative_position_bucket
+        )  # shape (query_length, key_length, num_heads)
         values = tf.expand_dims(
             tf.transpose(values, [2, 0, 1]), axis=0
         )  # shape (1, num_heads, query_length, key_length)
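Since the bias is now a plain weight, the embedding lookup becomes a tf.gather. The shapes work out the same way as with the old Embedding layer, which is easy to verify eagerly (illustrative sizes only):

import tensorflow as tf

# Looking up a [num_buckets, n_heads] table with a [query_length, key_length]
# bucket matrix gives one bias value per head and position pair.
num_buckets, n_heads, q_len, k_len = 32, 8, 4, 4
table = tf.random.normal((num_buckets, n_heads))
buckets = tf.random.uniform((q_len, k_len), maxval=num_buckets, dtype=tf.int32)

values = tf.gather(table, buckets)                               # (q_len, k_len, n_heads)
bias = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)   # (1, n_heads, q_len, k_len)
print(bias.shape)  # (1, 8, 4, 4)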
@@ -326,7 +334,7 @@ class TFT5Attention(tf.keras.layers.Layer):
         if position_bias is None:
             if not self.has_relative_attention_bias:
-                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32)
+                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length))
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)
@@ -336,6 +344,7 @@ class TFT5Attention(tf.keras.layers.Layer):
                 position_bias = position_bias[:, :, -seq_length:, :]
 
             if mask is not None:
+                position_bias = tf.cast(position_bias, dtype=mask.dtype)
                 position_bias = position_bias + mask  # (batch_size, n_heads, query_length, key_length)
 
         scores += position_bias
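Casting position_bias to mask.dtype matters once a mixed_float16 policy is active: the attention scores coming out of the Dense projections are then float16, while freshly created constants default to float32, and TensorFlow refuses to add the two. A minimal reproduction of that dtype mismatch, independent of T5 (a sketch; assumes a TensorFlow version with the non-experimental mixed_precision API):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

dense = tf.keras.layers.Dense(4)
scores = dense(tf.random.normal((2, 4)))   # float16 under mixed_float16
position_bias = tf.zeros((2, 4))           # float32 by default

# scores + position_bias would raise InvalidArgumentError (half vs float);
# aligning the dtypes first, as the hunk above does, works.
scores += tf.cast(position_bias, dtype=scores.dtype)

tf.keras.mixed_precision.set_global_policy("float32")  # restore the default policy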
@@ -662,7 +671,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
+        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype)
         num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
         if num_dims_attention_mask == 3:
             extended_attention_mask = inputs["attention_mask"][:, None, :, :]
@@ -676,7 +685,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                 tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                 seq_ids[None, :, None],
             )
-            causal_mask = tf.cast(causal_mask, dtype=tf.float32)
+            causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
             extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
             if inputs["past_key_values"][0] is not None:
                 extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
@@ -700,7 +709,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             # If a 2D ou 3D attention mask is provided for the cross-attention
             # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
             # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32)
+            inputs["encoder_attention_mask"] = tf.cast(
+                inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
+            )
             num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
             if num_dims_encoder_attention_mask == 3:
                 encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
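These three mask hunks follow the same idea: instead of dtype=tf.float32, each cast takes its dtype from a tensor that already carries the model's compute dtype (the input embeddings, the already-cast attention mask, or the extended mask), so the additive masks stay in a single dtype under AMP and through a compiled graph. The broadcastable additive-mask recipe they feed into looks roughly like this; a sketch of the general pattern, not the exact T5 code, with a made-up helper name:

import tensorflow as tf

def make_extended_attention_mask(attention_mask, inputs_embeds):
    # Follow the embeddings' dtype instead of hard-coding tf.float32.
    attention_mask = tf.cast(attention_mask, dtype=inputs_embeds.dtype)
    # [batch, seq] -> [batch, 1, 1, seq], broadcastable over heads and query positions.
    extended = attention_mask[:, None, None, :]
    # 1 -> keep (adds 0 to the scores), 0 -> mask (adds a large negative value).
    return (1.0 - extended) * -1e9

embeds = tf.random.normal((2, 5, 8))
mask = tf.constant([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(make_extended_attention_mask(mask, embeds).shape)  # (2, 1, 1, 5)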
@@ -868,8 +879,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
             decoder_start_token_id is not None
         ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
 
-        shifted_input_ids = tf.cast(input_ids, tf.int32)
-        shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+        shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
         start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
         shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
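The shift-right change drops the up-front tf.cast(input_ids, tf.int32) and rolls the ids directly, so the labels keep their original integer dtype. The roll/fill/concat sequence itself is easy to check eagerly (the token ids below are illustrative, and decoder_start_token_id = 0 is assumed, as is typical for T5):

import tensorflow as tf

decoder_start_token_id = 0  # for T5 this is usually the pad token id
input_ids = tf.constant([[8774, 6, 149, 33, 25, 1]])  # illustrative ids

shifted = tf.roll(input_ids, 1, axis=-1)                        # last token wraps to the front
start_tokens = tf.fill((tf.shape(input_ids)[0], 1), decoder_start_token_id)
shifted = tf.concat([start_tokens, shifted[:, 1:]], -1)         # overwrite the wrapped token

print(shifted.numpy())  # [[   0 8774    6  149   33   25]]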
@@ -880,7 +890,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
         )
         # "Verify that `labels` has only positive values and -100"
-        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
 
         # Make sure the assertion op is called by wrapping the result in an identity no-op
         with tf.control_dependencies([assert_gte0]):
tests/test_modeling_tf_t5.py
@@ -305,14 +305,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make T5 float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make T5 XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         model = TFT5Model.from_pretrained("t5-small")
@@ -442,14 +434,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_train_pipeline_custom_model(self):
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make T5 float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make T5 XLA compliant
-        pass
-
     @require_tf
     @require_sentencepiece
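Deleting the two pass-only overrides re-enables the shared test_mixed_precision and test_xla_mode checks from TFModelTesterMixin for both T5 test classes, so the common suite now exercises T5 under a float16 policy and under XLA. In spirit, those checks do something like the following simplified sketch; it is not the actual mixin code, the tiny config values are made up, and tf.function(jit_compile=True) assumes a recent TensorFlow (older versions used experimental_compile):

import tensorflow as tf
from transformers import T5Config, TFT5Model

# Deliberately tiny, hypothetical config so the check runs fast.
config = T5Config(vocab_size=99, d_model=16, d_kv=4, d_ff=32, num_layers=2, num_heads=2)
ids = tf.constant([[3, 5, 7, 9]])

# Mixed precision: the model should build and run under a float16 policy.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
TFT5Model(config)(ids, decoder_input_ids=ids)
tf.keras.mixed_precision.set_global_policy("float32")

# XLA: the forward pass should survive compilation.
model = TFT5Model(config)
run = tf.function(lambda x: model(x, decoder_input_ids=x), jit_compile=True)
run(ids)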