chenpangpang / transformers
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "dbac8899fe49275a794e6be36fa32662e14fb6bc"
Commit 86caeb76 (unverified)
Authored Feb 19, 2021 by Julien Plu; committed by GitHub on Feb 19, 2021
Fix XLA and AMP (#10262)
Parent: 3d72d47f
Showing 2 changed files with 28 additions and 34 deletions
src/transformers/models/t5/modeling_tf_t5.py  +28 -18
tests/test_modeling_tf_t5.py                   +0 -16
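The change replaces the hard-coded tf.float32 casts in the TF T5 model with casts that follow the dtype of an existing tensor, and swaps the tf.keras.layers.Embedding used for the relative attention bias for a plain weight read with tf.gather, so the model works under XLA compilation and automatic mixed precision (AMP), per the commit title and the test skips removed below. A minimal sketch of the dtype issue, with hypothetical tensors standing in for the model's activations and mask:

import tensorflow as tf

# Under the mixed_float16 policy Keras keeps activations in float16, so a
# hard-coded tf.float32 cast reintroduces a dtype mismatch when the mask is
# later combined with the attention scores. Tensors below are stand-ins.
inputs_embeds = tf.zeros((2, 4, 8), dtype=tf.float16)   # float16 activations under AMP
attention_mask = tf.ones((2, 4), dtype=tf.int32)

mask_before = tf.cast(attention_mask, dtype=tf.float32)          # old pattern
mask_after = tf.cast(attention_mask, dtype=inputs_embeds.dtype)  # pattern used by the commit
print(mask_before.dtype, mask_after.dtype)  # float32 float16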
src/transformers/models/t5/modeling_tf_t5.py
@@ -169,14 +169,18 @@ class TFT5Attention(tf.keras.layers.Layer):
         self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

-        if self.has_relative_attention_bias:
-            self.relative_attention_bias = tf.keras.layers.Embedding(
-                self.relative_attention_num_buckets,
-                self.n_heads,
-                name="relative_attention_bias",
-            )
         self.pruned_heads = set()

+    def build(self, input_shape):
+        if self.has_relative_attention_bias:
+            with tf.name_scope("relative_attention_bias"):
+                self.relative_attention_bias = self.add_weight(
+                    name="embeddings",
+                    shape=[self.relative_attention_num_buckets, self.n_heads],
+                )
+
+        return super().build(input_shape)
+
     def prune_heads(self, heads):
         raise NotImplementedError
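The relative attention bias table moves from a tf.keras.layers.Embedding sublayer to a weight created in build() and looked up with tf.gather; the tf.name_scope plus name="embeddings" mirrors the variable path the Embedding layer produced, presumably to keep existing checkpoints loadable. A standalone sketch of the pattern (the class name and sizes are hypothetical):

import tensorflow as tf

class RelativeBiasTable(tf.keras.layers.Layer):
    # Hypothetical layer illustrating the pattern the hunk adopts: build the
    # table with add_weight() and read it with tf.gather, instead of wrapping
    # it in tf.keras.layers.Embedding.
    def __init__(self, num_buckets, n_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_buckets = num_buckets
        self.n_heads = n_heads

    def build(self, input_shape):
        with tf.name_scope("relative_attention_bias"):
            self.relative_attention_bias = self.add_weight(
                name="embeddings", shape=[self.num_buckets, self.n_heads]
            )
        return super().build(input_shape)

    def call(self, bucket_ids):
        # (query_length, key_length) integer ids -> (query_length, key_length, n_heads)
        return tf.gather(self.relative_attention_bias, bucket_ids)

layer = RelativeBiasTable(num_buckets=32, n_heads=8)
print(layer(tf.zeros((3, 3), dtype=tf.int32)).shape)  # (3, 3, 8)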
@@ -206,18 +210,20 @@ class TFT5Attention(tf.keras.layers.Layer):
         # n = -relative_position
         if bidirectional:
             num_buckets //= 2
-            relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets
+            relative_buckets += (
+                tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets
+            )
             relative_position = tf.math.abs(relative_position)
         else:
             relative_position = -tf.math.minimum(relative_position, 0)
         # now n is in the range [0, inf)
         max_exact = num_buckets // 2
         is_small = tf.math.less(relative_position, max_exact)
-        relative_position_if_large = max_exact + tf.dtypes.cast(
-            tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact)
+        relative_position_if_large = max_exact + tf.cast(
+            tf.math.log(relative_position / max_exact)
             / math.log(max_distance / max_exact)
             * (num_buckets - max_exact),
-            tf.int32,
+            dtype=relative_position.dtype,
         )
         relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
         relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
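In the bucket computation the cast target becomes relative_position.dtype rather than a hard-coded tf.int32, and the inner float32 cast goes away because TensorFlow's true division already promotes integer positions to a float tensor before tf.math.log. A quick check of that expression with hypothetical values:

import math
import tensorflow as tf

relative_position = tf.constant([[17, 33, 90]], dtype=tf.int32)  # hypothetical positions
max_exact, num_buckets, max_distance = 16, 32, 128

# Same shape of expression as the new code path: int32 / int promotes to a
# float tensor for the log, and the final cast follows relative_position.dtype.
relative_position_if_large = max_exact + tf.cast(
    tf.math.log(relative_position / max_exact)
    / math.log(max_distance / max_exact)
    * (num_buckets - max_exact),
    dtype=relative_position.dtype,
)
print(relative_position_if_large.dtype)  # int32, matching relative_position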
@@ -233,7 +239,9 @@ class TFT5Attention(tf.keras.layers.Layer):
             bidirectional=(not self.is_decoder),
             num_buckets=self.relative_attention_num_buckets,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = tf.gather(
+            self.relative_attention_bias, relative_position_bucket
+        )  # shape (query_length, key_length, num_heads)
         values = tf.expand_dims(
             tf.transpose(values, [2, 0, 1]), axis=0
         )  # shape (1, num_heads, query_length, key_length)
@@ -326,7 +334,7 @@ class TFT5Attention(tf.keras.layers.Layer):
         if position_bias is None:
             if not self.has_relative_attention_bias:
-                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32)
+                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length))
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)
@@ -336,6 +344,7 @@ class TFT5Attention(tf.keras.layers.Layer):
                 position_bias = position_bias[:, :, -seq_length:, :]

             if mask is not None:
+                position_bias = tf.cast(position_bias, dtype=mask.dtype)
                 position_bias = position_bias + mask  # (batch_size, n_heads, query_length, key_length)

         scores += position_bias
@@ -662,7 +671,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
+        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype)
         num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
         if num_dims_attention_mask == 3:
             extended_attention_mask = inputs["attention_mask"][:, None, :, :]
@@ -676,7 +685,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                     tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                     seq_ids[None, :, None],
                 )
-                causal_mask = tf.cast(causal_mask, dtype=tf.float32)
+                causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
                 extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
                 if inputs["past_key_values"][0] is not None:
                     extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
@@ -700,7 +709,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             # If a 2D ou 3D attention mask is provided for the cross-attention
             # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
             # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32)
+            inputs["encoder_attention_mask"] = tf.cast(
+                inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
+            )
             num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
             if num_dims_encoder_attention_mask == 3:
                 encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
@@ -868,8 +879,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
             decoder_start_token_id is not None
         ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"

-        shifted_input_ids = tf.cast(input_ids, tf.int32)
-        shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+        shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
         start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
         shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
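The explicit cast of input_ids to tf.int32 before the shift is dropped; tf.roll preserves the dtype of its input, so the shifted ids keep whatever integer dtype the caller passed in. A quick check with hypothetical ids:

import tensorflow as tf

input_ids = tf.constant([[5, 6, 7, 1]], dtype=tf.int64)  # hypothetical label ids
shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
print(shifted_input_ids.dtype, shifted_input_ids.numpy())  # int64 [[1 5 6 7]]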
@@ -880,7 +890,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
         )

         # "Verify that `labels` has only positive values and -100"
-        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))

         # Make sure the assertion op is called by wrapping the result in an identity no-op
         with tf.control_dependencies([assert_gte0]):
tests/test_modeling_tf_t5.py
@@ -305,14 +305,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass

-    def test_mixed_precision(self):
-        # TODO JP: Make T5 float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make T5 XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         model = TFT5Model.from_pretrained("t5-small")
@@ -442,14 +434,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -442,14 +434,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_train_pipeline_custom_model
(
self
):
def
test_train_pipeline_custom_model
(
self
):
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make T5 float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make T5 XLA compliant
pass
@
require_tf
@
require_tf
@
require_sentencepiece
@
require_sentencepiece
...
...