chenpangpang/transformers · Commits

Commit 31be02f1 (unverified), authored Sep 14, 2022 by Joao Gante, committed by GitHub on Sep 14, 2022
Parent commit: 693ba2cc

TF: tf.debugging assertions without tf.running_eagerly() protection (#19030)

Changes: 18 changed files, +696 additions, -947 deletions. Diffs below are shown inline, with whitespace-only changes hidden.
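Every hunk in this commit applies the same change: the `tf.debugging` shape assertions in the TF models were previously wrapped in an `if tf.executing_eagerly():` guard (usually together with a comment saying the asserts had to be disabled outside eager mode), and the guard and comment are removed so the assertions are also emitted when the model is traced into a graph. A minimal sketch of the before/after pattern, using a hypothetical `check_shape_*` helper and `doubled` function rather than any code from the diff:

import tensorflow as tf


def check_shape_old(x: tf.Tensor, expected) -> None:
    # Old pattern: the check only ran in eager mode, so a tf.function-traced
    # (graph/XLA) call skipped it silently.
    if tf.executing_eagerly():
        tf.debugging.assert_equal(tf.shape(x), expected, message="unexpected shape")


def check_shape_new(x: tf.Tensor, expected) -> None:
    # New pattern: the assertion is emitted unconditionally; under tf.function
    # it becomes a graph op that raises InvalidArgumentError on a mismatch.
    tf.debugging.assert_equal(tf.shape(x), expected, message="unexpected shape")


@tf.function
def doubled(x: tf.Tensor) -> tf.Tensor:
    check_shape_new(x, [2, 3])  # now checked even though execution is not eager here
    return x * 2


print(doubled(tf.ones((2, 3))))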
Files changed (18):

  src/transformers/models/bart/modeling_tf_bart.py  (+42, -62)
  src/transformers/models/blenderbot/modeling_tf_blenderbot.py  (+42, -62)
  src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py  (+42, -62)
  src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py  (+5, -6)
  src/transformers/models/flaubert/modeling_tf_flaubert.py  (+12, -14)
  src/transformers/models/hubert/modeling_tf_hubert.py  (+30, -42)
  src/transformers/models/led/modeling_tf_led.py  (+130, -151)
  src/transformers/models/longformer/modeling_tf_longformer.py  (+89, -104)
  src/transformers/models/marian/modeling_tf_marian.py  (+42, -62)
  src/transformers/models/mbart/modeling_tf_mbart.py  (+37, -56)
  src/transformers/models/opt/modeling_tf_opt.py  (+31, -45)
  src/transformers/models/pegasus/modeling_tf_pegasus.py  (+42, -62)
  src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py  (+42, -60)
  src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py  (+5, -6)
  src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py  (+30, -42)
  src/transformers/models/xglm/modeling_tf_xglm.py  (+31, -45)
  src/transformers/models/xlm/modeling_tf_xlm.py  (+12, -14)
  templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py  (+32, -52)
src/transformers/models/bart/modeling_tf_bart.py

@@ -71,13 +71,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
         shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
     )
 
-    if tf.executing_eagerly():
     # "Verify that `labels` has only positive values and -100"
     assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
 
     # Make sure the assertion op is called by wrapping the result in an identity no-op
     with tf.control_dependencies([assert_gte0]):
         shifted_input_ids = tf.identity(shifted_input_ids)
 
     return shifted_input_ids

@@ -229,31 +228,25 @@ class TFBartAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_weights),
             [bsz * self.num_heads, tgt_len, src_len],
             message=(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {shape_list(attn_weights)}"
             ),
         )
 
         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attention_mask),
                 [bsz, 1, tgt_len, src_len],
                 message=(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {shape_list(attention_mask)}"
                 ),
             )
             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

@@ -261,17 +254,14 @@ class TFBartAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)
 
         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
             )

@@ -281,17 +271,14 @@ class TFBartAttention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_output),
             [bsz * self.num_heads, tgt_len, self.head_dim],
             message=(
                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {shape_list(attn_output)}"
             ),
         )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
         )

@@ -339,14 +326,11 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
             hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
         )
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(hidden_states),
             shape_list(residual),
             message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
         )
 
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states

@@ -776,9 +760,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
         all_attentions = () if output_attentions else None
 
         # check if head_mask has a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if head_mask is not None and tf.executing_eagerly():
+        if head_mask is not None:
             tf.debugging.assert_equal(
                 shape_list(head_mask)[0],
                 len(self.layers),

@@ -983,10 +965,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
         present_key_values = () if use_cache else None
 
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
         for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-            if attn_mask is not None and tf.executing_eagerly():
+            if attn_mask is not None:
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
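One detail worth noting in the `shift_tokens_right` hunks: even with the eager guard gone, the result is still wrapped in `tf.identity` under `tf.control_dependencies([assert_gte0])`, because the assertion op has no data dependency on the returned tensor and the original comment says this wrapping is what guarantees the assert is actually called. A self-contained sketch of that pattern with toy labels and an invented `replace_ignore_index` function, not taken from the model code:

import tensorflow as tf


@tf.function  # traced into a graph, i.e. tf.executing_eagerly() is False inside
def replace_ignore_index(labels: tf.Tensor, pad_token_id: int) -> tf.Tensor:
    # Replace the -100 "ignore" positions with the pad token, as shift_tokens_right does.
    out = tf.where(labels == -100, tf.fill(tf.shape(labels), pad_token_id), labels)

    # The assertion has no data dependency on `out`, so it is attached to the
    # returned tensor through a control dependency around an identity no-op.
    assert_gte0 = tf.debugging.assert_greater_equal(out, tf.constant(0, dtype=labels.dtype))
    with tf.control_dependencies([assert_gte0]):
        out = tf.identity(out)
    return out


labels = tf.constant([[5, 6, -100], [7, -100, -100]], dtype=tf.int32)
print(replace_ignore_index(labels, pad_token_id=1))  # [[5 6 1] [7 1 1]]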
src/transformers/models/blenderbot/modeling_tf_blenderbot.py

@@ -73,13 +73,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
         shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
     )
 
-    if tf.executing_eagerly():
     # "Verify that `labels` has only positive values and -100"
     assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
 
     # Make sure the assertion op is called by wrapping the result in an identity no-op
     with tf.control_dependencies([assert_gte0]):
         shifted_input_ids = tf.identity(shifted_input_ids)
 
     return shifted_input_ids

@@ -225,31 +224,25 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_weights),
             [bsz * self.num_heads, tgt_len, src_len],
             message=(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {shape_list(attn_weights)}"
             ),
         )
 
         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attention_mask),
                 [bsz, 1, tgt_len, src_len],
                 message=(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {shape_list(attention_mask)}"
                 ),
             )
             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

@@ -257,17 +250,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)
 
         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
             )

@@ -277,17 +267,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_output),
             [bsz * self.num_heads, tgt_len, self.head_dim],
             message=(
                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {shape_list(attn_output)}"
             ),
         )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
         )

@@ -337,14 +324,11 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
             hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
         )
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(hidden_states),
             shape_list(residual),
             message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
         )
 
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states

@@ -755,9 +739,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
         all_attentions = () if output_attentions else None
 
         # check if head_mask has a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if head_mask is not None and tf.executing_eagerly():
+        if head_mask is not None:
             tf.debugging.assert_equal(
                 shape_list(head_mask)[0],
                 len(self.layers),

@@ -966,10 +948,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
         present_key_values = () if use_cache else None
 
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
         for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-            if attn_mask is not None and tf.executing_eagerly():
+            if attn_mask is not None:
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py

@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
         shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
     )
 
-    if tf.executing_eagerly():
     # "Verify that `labels` has only positive values and -100"
     assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
 
     # Make sure the assertion op is called by wrapping the result in an identity no-op
     with tf.control_dependencies([assert_gte0]):
         shifted_input_ids = tf.identity(shifted_input_ids)
 
     return shifted_input_ids

@@ -225,31 +224,25 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_weights),
             [bsz * self.num_heads, tgt_len, src_len],
             message=(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {shape_list(attn_weights)}"
             ),
         )
 
         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attention_mask),
                 [bsz, 1, tgt_len, src_len],
                 message=(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {shape_list(attention_mask)}"
                 ),
             )
             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

@@ -257,17 +250,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)
 
         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
             )

@@ -277,17 +267,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_output),
             [bsz * self.num_heads, tgt_len, self.head_dim],
             message=(
                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {shape_list(attn_output)}"
             ),
         )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
         )

@@ -336,14 +323,11 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
             hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
         )
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(hidden_states),
             shape_list(residual),
             message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
         )
 
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states

@@ -761,9 +745,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
         all_attentions = () if output_attentions else None
 
         # check if head_mask has a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if head_mask is not None and tf.executing_eagerly():
+        if head_mask is not None:
             tf.debugging.assert_equal(
                 shape_list(head_mask)[0],
                 len(self.layers),

@@ -968,10 +950,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
         present_key_values = () if use_cache else None
 
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
         for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-            if attn_mask is not None and tf.executing_eagerly():
+            if attn_mask is not None:
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py

@@ -171,13 +171,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
         shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
     )
 
-    if tf.executing_eagerly():
     # "Verify that `labels` has only positive values and -100"
     assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
 
     # Make sure the assertion op is called by wrapping the result in an identity no-op
     with tf.control_dependencies([assert_gte0]):
         shifted_input_ids = tf.identity(shifted_input_ids)
 
     return shifted_input_ids
src/transformers/models/flaubert/modeling_tf_flaubert.py

@@ -200,9 +200,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
     # sanity check
     # assert shape_list(mask) == [bs, slen]
-    if tf.executing_eagerly():
-        tf.debugging.assert_equal(shape_list(mask), [bs, slen])
-        assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
+    tf.debugging.assert_equal(shape_list(mask), [bs, slen])
+    if causal:
+        tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
 
     return mask, attn_mask

@@ -517,10 +517,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
             # check inputs
             # assert shape_list(lengths)[0] == bs
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(lengths)[0], bs
             ), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
         # assert lengths.max().item() <= slen
         # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
         # assert (src_enc is None) == (src_len is None)

@@ -538,15 +537,14 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
             position_ids = tf.expand_dims(tf.range(slen), axis=0)
             position_ids = tf.tile(position_ids, (bs, 1))
 
-        if tf.executing_eagerly():
         # assert shape_list(position_ids) == [bs, slen] # (slen, bs)
         tf.debugging.assert_equal(
             shape_list(position_ids), [bs, slen]
         ), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
         # position_ids = position_ids.transpose(0, 1)
 
         # langs
-        if langs is not None and tf.executing_eagerly():
+        if langs is not None:
             # assert shape_list(langs) == [bs, slen] # (slen, bs)
             tf.debugging.assert_equal(
                 shape_list(langs), [bs, slen]
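In the flaubert `get_masks` hunk the change goes one step further: a plain Python `assert causal is False or ...` is replaced by `tf.debugging.assert_equal` under `if causal:`. A Python assert only fires while the function is being traced and compares whatever static values it sees at that point, whereas the `tf.debugging` op is embedded in the graph and checks the runtime shapes. A small sketch of that distinction, using an invented `check_square` toy function that is not part of the diff:

import tensorflow as tf


@tf.function
def check_square(x: tf.Tensor) -> tf.Tensor:
    # Python assert: evaluated once at trace time; it can only inspect static
    # information such as the rank, not values computed inside the graph.
    assert len(x.shape) == 2, "expected a rank-2 tensor"

    # Graph-level assert: becomes an op that validates the actual dimensions.
    tf.debugging.assert_equal(tf.shape(x)[0], tf.shape(x)[1], message="expected a square matrix")
    return tf.linalg.trace(x)


print(check_square(tf.eye(3)))   # passes both checks, prints 3.0
# check_square(tf.ones((2, 3)))  # would raise InvalidArgumentError from the graph assert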
src/transformers/models/hubert/modeling_tf_hubert.py

@@ -816,31 +816,25 @@ class TFHubertAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_weights),
             [bsz * self.num_heads, tgt_len, src_len],
             message=(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {shape_list(attn_weights)}"
             ),
         )
 
         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attention_mask),
                 [bsz, 1, tgt_len, src_len],
                 message=(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {shape_list(attention_mask)}"
                 ),
             )
             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

@@ -848,17 +842,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)
 
         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
             )

@@ -868,17 +859,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_output),
             [bsz * self.num_heads, tgt_len, self.head_dim],
             message=(
                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {shape_list(attn_output)}"
             ),
         )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
         )
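The attention hunks in bart, blenderbot, blenderbot_small, and hubert all enforce the same invariant: after the heads are folded into the batch dimension, the raw attention scores must have shape [bsz * num_heads, tgt_len, src_len]. A toy illustration of that invariant with invented sizes (none of the values below come from the diff):

import tensorflow as tf

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 8

# query/key states with the heads already folded into the batch dimension,
# as the TF*Attention layers arrange them before calling tf.matmul
query_states = tf.random.normal((bsz * num_heads, tgt_len, head_dim))
key_states = tf.random.normal((bsz * num_heads, src_len, head_dim))

attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

# the invariant that the now-unguarded assertion checks
tf.debugging.assert_equal(
    tf.shape(attn_weights),
    [bsz * num_heads, tgt_len, src_len],
    message="Attention weights have an unexpected shape",
)
print(attn_weights.shape)  # (8, 5, 7)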
src/transformers/models/led/modeling_tf_led.py
View file @
31be02f1
...
@@ -64,12 +64,11 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -64,12 +64,11 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
)
)
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
if
tf
.
executing_eagerly
():
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
@@ -213,12 +212,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -213,12 +212,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
value_vectors
=
self
.
value
(
hidden_states
)
value_vectors
=
self
.
value
(
hidden_states
)
batch_size
,
seq_len
,
embed_dim
=
shape_list
(
hidden_states
)
batch_size
,
seq_len
,
embed_dim
=
shape_list
(
hidden_states
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
embed_dim
,
embed_dim
,
self
.
embed_dim
,
self
.
embed_dim
,
message
=
f
"hidden_states should have embed_dim =
{
self
.
embed_dim
}
, but has
{
embed_dim
}
"
,
message
=
f
"hidden_states should have embed_dim =
{
self
.
embed_dim
}
, but has
{
embed_dim
}
"
,
)
)
# normalize query
# normalize query
query_vectors
/=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
head_dim
,
dtype
=
query_vectors
.
dtype
))
query_vectors
/=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
head_dim
,
dtype
=
query_vectors
.
dtype
))
...
@@ -245,15 +243,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -245,15 +243,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# pad local attention probs
# pad local attention probs
attn_scores
+=
diagonal_mask
attn_scores
+=
diagonal_mask
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_scores
),
shape_list
(
attn_scores
),
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
one_sided_attn_window_size
*
2
+
1
],
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
one_sided_attn_window_size
*
2
+
1
],
message
=
(
message
=
(
f
"attn_probs should be of size (
{
batch_size
}
,
{
seq_len
}
,
{
self
.
num_heads
}
,"
f
"attn_probs should be of size (
{
batch_size
}
,
{
seq_len
}
,
{
self
.
num_heads
}
,"
f
"
{
self
.
one_sided_attn_window_size
*
2
+
1
}
), but is of size
{
shape_list
(
attn_scores
)
}
"
f
"
{
self
.
one_sided_attn_window_size
*
2
+
1
}
), but is of size
{
shape_list
(
attn_scores
)
}
"
),
),
)
)
# compute global attn indices required through out forward fn
# compute global attn indices required through out forward fn
(
(
...
@@ -301,15 +298,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -301,15 +298,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
)
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
layer_head_mask
),
shape_list
(
layer_head_mask
),
[
self
.
num_heads
],
[
self
.
num_heads
],
message
=
(
message
=
(
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"
{
shape_list
(
layer_head_mask
)
}
"
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
),
)
)
attn_probs
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
1
,
-
1
,
1
))
*
attn_probs
attn_probs
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
1
,
-
1
,
1
))
*
attn_probs
...
@@ -332,12 +328,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -332,12 +328,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
),
),
)
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_output
),
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
],
message
=
"Unexpected size"
shape_list
(
attn_output
),
)
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
],
message
=
"Unexpected size"
,
)
attn_output
=
tf
.
reshape
(
attn_output
,
(
batch_size
,
seq_len
,
embed_dim
))
attn_output
=
tf
.
reshape
(
attn_output
,
(
batch_size
,
seq_len
,
embed_dim
))
...
@@ -392,20 +385,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -392,20 +385,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
"""
"""
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
query
)
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
query
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
seq_len
%
(
window_overlap
*
2
),
seq_len
%
(
window_overlap
*
2
),
0
,
0
,
message
=
f
"Sequence length should be multiple of
{
window_overlap
*
2
}
. Given
{
seq_len
}
"
,
message
=
f
"Sequence length should be multiple of
{
window_overlap
*
2
}
. Given
{
seq_len
}
"
,
)
)
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
query
),
shape_list
(
query
),
shape_list
(
key
),
shape_list
(
key
),
message
=
(
message
=
(
f
"Shape of query and key should be equal, but got query:
{
shape_list
(
query
)
}
and key:"
f
"Shape of query and key should be equal, but got query:
{
shape_list
(
query
)
}
and key:"
f
"
{
shape_list
(
key
)
}
"
f
"
{
shape_list
(
key
)
}
"
),
),
)
)
chunks_count
=
seq_len
//
window_overlap
-
1
chunks_count
=
seq_len
//
window_overlap
-
1
...
@@ -539,22 +531,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -539,22 +531,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
value
)
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
value
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
seq_len
%
(
window_overlap
*
2
),
0
,
message
=
"Seq_len has to be multiple of 2 * window_overlap"
seq_len
%
(
window_overlap
*
2
),
)
0
,
tf
.
debugging
.
assert_equal
(
message
=
"Seq_len has to be multiple of 2 * window_overlap"
,
shape_list
(
attn_probs
)[:
3
],
)
shape_list
(
value
)[:
3
],
tf
.
debugging
.
assert_equal
(
message
=
"value and attn_probs must have same dims (except head_dim)"
,
shape_list
(
attn_probs
)[:
3
],
)
shape_list
(
value
)[:
3
],
tf
.
debugging
.
assert_equal
(
message
=
"value and attn_probs must have same dims (except head_dim)"
,
shape_list
(
attn_probs
)[
3
],
)
2
*
window_overlap
+
1
,
tf
.
debugging
.
assert_equal
(
message
=
"attn_probs last dim has to be 2 * window_overlap + 1"
,
shape_list
(
attn_probs
)[
3
],
)
2
*
window_overlap
+
1
,
message
=
"attn_probs last dim has to be 2 * window_overlap + 1"
,
)
chunks_count
=
seq_len
//
window_overlap
-
1
chunks_count
=
seq_len
//
window_overlap
-
1
...
@@ -592,12 +581,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -592,12 +581,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
(
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
),
(
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
),
)
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
chunked_value
),
shape_list
(
chunked_value
),
[
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
],
[
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
],
message
=
"Chunked value has the wrong shape"
,
message
=
"Chunked value has the wrong shape"
,
)
)
chunked_attn_probs
=
self
.
_pad_and_diagonalize
(
chunked_attn_probs
)
chunked_attn_probs
=
self
.
_pad_and_diagonalize
(
chunked_attn_probs
)
context
=
tf
.
einsum
(
"bcwd,bcdh->bcwh"
,
chunked_attn_probs
,
chunked_value
)
context
=
tf
.
einsum
(
"bcwd,bcdh->bcwh"
,
chunked_attn_probs
,
chunked_value
)
...
@@ -685,15 +673,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -685,15 +673,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# chunk with overlap
# chunk with overlap
chunked_hidden_states
=
tf
.
signal
.
frame
(
hidden_states
,
frame_size
,
frame_hop_size
)
chunked_hidden_states
=
tf
.
signal
.
frame
(
hidden_states
,
frame_size
,
frame_hop_size
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
chunked_hidden_states
),
shape_list
(
chunked_hidden_states
),
[
batch_size
,
num_output_chunks
,
frame_size
],
[
batch_size
,
num_output_chunks
,
frame_size
],
message
=
(
message
=
(
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
f
"
{
[
batch_size
,
frame_size
,
num_output_chunks
]
}
, but got
{
shape_list
(
chunked_hidden_states
)
}
."
f
"
{
[
batch_size
,
frame_size
,
num_output_chunks
]
}
, but got
{
shape_list
(
chunked_hidden_states
)
}
."
),
),
)
)
chunked_hidden_states
=
tf
.
reshape
(
chunked_hidden_states
=
tf
.
reshape
(
chunked_hidden_states
,
chunked_hidden_states
,
...
@@ -866,16 +853,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -866,16 +853,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# compute attn scores
# compute attn scores
global_attn_scores
=
tf
.
matmul
(
global_query_vectors_only_global
,
global_key_vectors
,
transpose_b
=
True
)
global_attn_scores
=
tf
.
matmul
(
global_query_vectors_only_global
,
global_key_vectors
,
transpose_b
=
True
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
global_attn_scores
),
shape_list
(
global_attn_scores
),
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
],
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
],
message
=
(
message
=
(
"global_attn_scores have the wrong size. Size should be"
"global_attn_scores have the wrong size. Size should be"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
}
, but is"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
}
, but is"
f
"
{
shape_list
(
global_attn_scores
)
}
."
f
"
{
shape_list
(
global_attn_scores
)
}
."
),
),
)
)
global_attn_scores
=
tf
.
reshape
(
global_attn_scores
=
tf
.
reshape
(
global_attn_scores
,
global_attn_scores
,
...
@@ -909,15 +895,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -909,15 +895,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# apply layer head masking
# apply layer head masking
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
layer_head_mask
),
shape_list
(
layer_head_mask
),
[
self
.
num_heads
],
[
self
.
num_heads
],
message
=
(
message
=
(
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"
{
shape_list
(
layer_head_mask
)
}
"
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
),
)
)
global_attn_probs_float
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
global_attn_probs_float
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
global_attn_probs_float
,
(
batch_size
,
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
global_attn_probs_float
,
(
batch_size
,
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
)
)
...
@@ -931,16 +916,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -931,16 +916,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# global attn output
# global attn output
global_attn_output
=
tf
.
matmul
(
global_attn_probs
,
global_value_vectors
)
global_attn_output
=
tf
.
matmul
(
global_attn_probs
,
global_value_vectors
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
global_attn_output
),
shape_list
(
global_attn_output
),
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
],
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
],
message
=
(
message
=
(
"global_attn_output tensor has the wrong size. Size should be"
"global_attn_output tensor has the wrong size. Size should be"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
)
}
, but is"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
)
}
, but is"
f
"
{
shape_list
(
global_attn_output
)
}
."
f
"
{
shape_list
(
global_attn_output
)
}
."
),
),
)
)
global_attn_output
=
tf
.
reshape
(
global_attn_output
=
tf
.
reshape
(
global_attn_output
,
global_attn_output
,
...
@@ -1091,27 +1075,25 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
...
@@ -1091,27 +1075,25 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
 
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_weights),
             [bsz * self.num_heads, tgt_len, src_len],
             message=(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {shape_list(attn_weights)}"
             ),
         )
 
         if attention_mask is not None:
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attention_mask),
                 [bsz, 1, tgt_len, src_len],
                 message=(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {shape_list(attention_mask)}"
                 ),
             )
 
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
                 attention_mask, dtype=attn_weights.dtype
             )
@@ -1120,15 +1102,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)
 
         if layer_head_mask is not None:
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
 
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@@ -1139,15 +1120,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         attn_output = tf.matmul(attn_probs, value_states)
 
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_output),
             [bsz * self.num_heads, tgt_len, self.head_dim],
             message=(
                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {shape_list(attn_output)}"
             ),
         )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -1199,12 +1179,11 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
         hidden_states = layer_outputs[0]
 
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(hidden_states),
             shape_list(residual),
             message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
         )
 
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states
@@ -1792,7 +1771,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         all_attentions = all_global_attentions = () if output_attentions else None
 
         # check if head_mask has a correct number of layers specified if desired
-        if head_mask is not None and tf.executing_eagerly():
+        if head_mask is not None:
             tf.debugging.assert_equal(
                 shape_list(head_mask)[0],
                 len(self.layers),
@@ -2055,7 +2034,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         present_key_values = ()
 
         # check if head_mask has a correct number of layers specified if desired
-        if head_mask is not None and tf.executing_eagerly():
+        if head_mask is not None:
             tf.debugging.assert_equal(
                 shape_list(head_mask)[0],
                 len(self.layers),
...
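For readers following the hunks above, the only behavioural question is whether a `tf.debugging` check still runs once the `tf.executing_eagerly()` guard is gone. Below is a minimal sketch, not taken from the repository (the `check_row_sums` helper and the toy tensors are invented), of `tf.debugging.assert_equal` in both modes: it raises immediately in eager execution, and inside `tf.function` it becomes a graph op that fires when the traced function runs. The removed comments talked about XLA; this sketch only covers plain graph mode, not `jit_compile=True`.

import tensorflow as tf

def check_row_sums(x):
    # Same style as the modeling code: a runtime check plus an explanatory message.
    row_sums = tf.reduce_sum(x, axis=-1)
    tf.debugging.assert_equal(
        row_sums,
        tf.ones_like(row_sums),
        message="every row of x should sum to 1",
    )
    return x * 2.0

ok = tf.constant([[0.5, 0.5], [1.0, 0.0]])
bad = tf.constant([[0.7, 0.5], [1.0, 0.0]])

check_row_sums(ok)                   # eager: the assert runs inline and passes
graph_fn = tf.function(check_row_sums)
graph_fn(ok)                         # graph mode: the assert is part of the traced graph

try:
    graph_fn(bad)                    # fails at run time, not at trace time
except tf.errors.InvalidArgumentError as err:
    print("assertion fired:", err.message.splitlines()[0])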
src/transformers/models/longformer/modeling_tf_longformer.py
@@ -738,12 +738,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         value_vectors = self.value(hidden_states)
 
         batch_size, seq_len, embed_dim = shape_list(hidden_states)
 
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             embed_dim,
             self.embed_dim,
             message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
         )
 
         # normalize query
         query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
@@ -770,15 +769,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # pad local attention probs
         attn_scores += diagonal_mask
 
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_scores),
             [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
             message=(
                 f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
                 f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
             ),
         )
 
         # compute global attn indices required through out forward fn
         (
@@ -826,15 +824,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )
 
         if layer_head_mask is not None:
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
 
             attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
@@ -857,12 +854,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             ),
         )
 
-        if tf.executing_eagerly():
-            tf.debugging.assert_equal(
-                shape_list(attn_output),
-                [batch_size, seq_len, self.num_heads, self.head_dim],
-                message="Unexpected size",
-            )
+        tf.debugging.assert_equal(
+            shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
+        )
 
         attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
@@ -917,20 +911,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         """
         batch_size, seq_len, num_heads, head_dim = shape_list(query)
 
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             seq_len % (window_overlap * 2),
             0,
             message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
         )
         tf.debugging.assert_equal(
             shape_list(query),
             shape_list(key),
             message=(
                 f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
                 f" {shape_list(key)}"
             ),
         )
 
         chunks_count = seq_len // window_overlap - 1
@@ -1064,22 +1057,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         batch_size, seq_len, num_heads, head_dim = shape_list(value)
 
-        if tf.executing_eagerly():
-            tf.debugging.assert_equal(
-                seq_len % (window_overlap * 2),
-                0,
-                message="Seq_len has to be multiple of 2 * window_overlap",
-            )
+        tf.debugging.assert_equal(
+            seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
+        )
         tf.debugging.assert_equal(
             shape_list(attn_probs)[:3],
             shape_list(value)[:3],
             message="value and attn_probs must have same dims (except head_dim)",
         )
         tf.debugging.assert_equal(
             shape_list(attn_probs)[3],
             2 * window_overlap + 1,
             message="attn_probs last dim has to be 2 * window_overlap + 1",
         )
 
         chunks_count = seq_len // window_overlap - 1
@@ -1117,12 +1107,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
         )
 
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(chunked_value),
             [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
             message="Chunked value has the wrong shape",
         )
 
         chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
         context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
@@ -1210,15 +1199,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # chunk with overlap
         chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
 
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(chunked_hidden_states),
             [batch_size, num_output_chunks, frame_size],
             message=(
                 "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
                 f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
             ),
         )
 
         chunked_hidden_states = tf.reshape(
             chunked_hidden_states,
@@ -1391,16 +1379,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # compute attn scores
         global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
 
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(global_attn_scores),
             [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
             message=(
                 "global_attn_scores have the wrong size. Size should be"
                 f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
                 f" {shape_list(global_attn_scores)}."
             ),
         )
 
         global_attn_scores = tf.reshape(
             global_attn_scores,
@@ -1434,15 +1421,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # apply layer head masking
         if layer_head_mask is not None:
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
 
             global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
             )
@@ -1456,16 +1442,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # global attn output
         global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
 
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(global_attn_output),
             [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
             message=(
                 "global_attn_output tensor has the wrong size. Size should be"
                 f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
                 f" {shape_list(global_attn_output)}."
             ),
         )
 
         global_attn_output = tf.reshape(
             global_attn_output,
...
src/transformers/models/marian/modeling_tf_marian.py
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
         shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
     )
 
-    if tf.executing_eagerly():
     # "Verify that `labels` has only positive values and -100"
     assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
 
     # Make sure the assertion op is called by wrapping the result in an identity no-op
     with tf.control_dependencies([assert_gte0]):
         shifted_input_ids = tf.identity(shifted_input_ids)
 
     return shifted_input_ids
@@ -264,31 +263,25 @@ class TFMarianAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_weights),
             [bsz * self.num_heads, tgt_len, src_len],
             message=(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {shape_list(attn_weights)}"
             ),
         )
 
         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attention_mask),
                 [bsz, 1, tgt_len, src_len],
                 message=(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {shape_list(attention_mask)}"
                 ),
             )
 
             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 
         attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
@@ -296,17 +289,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)
 
         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
 
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@@ -316,17 +306,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_output),
             [bsz * self.num_heads, tgt_len, self.head_dim],
             message=(
                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {shape_list(attn_output)}"
             ),
         )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -375,14 +362,11 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
             hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
         )
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(hidden_states),
             shape_list(residual),
             message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
         )
 
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states
@@ -801,9 +785,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
         all_attentions = () if output_attentions else None
 
         # check if head_mask has a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if head_mask is not None and tf.executing_eagerly():
+        if head_mask is not None:
             tf.debugging.assert_equal(
                 shape_list(head_mask)[0],
                 len(self.layers),
@@ -1009,10 +991,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
         present_key_values = () if use_cache else None
 
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
         for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-            if attn_mask is not None and tf.executing_eagerly():
+            if attn_mask is not None:
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
...
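The `shift_tokens_right` hunks keep the `tf.control_dependencies` / `tf.identity` wrapper even though the eager-only guard is gone. Here is a small standalone sketch of why that wrapper matters; the `clamp_labels` helper and the toy values are invented for illustration and are not repository code. Tying the assert op to the returned tensor ensures the check is actually executed when the graph is built and run, exactly as the in-tree comment describes.

import tensorflow as tf

def clamp_labels(labels: tf.Tensor, pad_token_id: int) -> tf.Tensor:
    # Replace the -100 "ignore" positions, then assert the result is non-negative.
    out = tf.where(labels == -100, tf.fill(tf.shape(labels), pad_token_id), labels)
    assert_gte0 = tf.debugging.assert_greater_equal(out, tf.constant(0, dtype=labels.dtype))
    # The identity op depends on the assert, so the assert cannot be pruned away.
    with tf.control_dependencies([assert_gte0]):
        out = tf.identity(out)
    return out

labels = tf.constant([[5, -100, 7]], dtype=tf.int32)
print(clamp_labels(labels, pad_token_id=1))  # [[5 1 7]]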
src/transformers/models/mbart/modeling_tf_mbart.py
@@ -232,31 +232,25 @@ class TFMBartAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_weights),
             [bsz * self.num_heads, tgt_len, src_len],
             message=(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {shape_list(attn_weights)}"
             ),
         )
 
         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attention_mask),
                 [bsz, 1, tgt_len, src_len],
                 message=(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {shape_list(attention_mask)}"
                 ),
             )
 
             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 
         attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
@@ -264,17 +258,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)
 
         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
 
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@@ -284,17 +275,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_output),
             [bsz * self.num_heads, tgt_len, self.head_dim],
             message=(
                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {shape_list(attn_output)}"
             ),
         )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -343,14 +331,11 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
             hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
         )
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(hidden_states),
             shape_list(residual),
             message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
         )
 
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states
@@ -786,9 +771,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
         all_attentions = () if output_attentions else None
 
         # check if head_mask has a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if head_mask is not None and tf.executing_eagerly():
+        if head_mask is not None:
             tf.debugging.assert_equal(
                 shape_list(head_mask)[0],
                 len(self.layers),
@@ -1001,10 +984,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
         present_key_values = () if use_cache else None
 
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
         for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-            if attn_mask is not None and tf.executing_eagerly():
+            if attn_mask is not None:
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
...
src/transformers/models/opt/modeling_tf_opt.py
@@ -206,31 +206,25 @@ class TFOPTAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_weights),
             [bsz * self.num_heads, tgt_len, src_len],
             message=(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {shape_list(attn_weights)}"
             ),
         )
 
         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attention_mask),
                 [bsz, 1, tgt_len, src_len],
                 message=(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {shape_list(attention_mask)}"
                 ),
             )
 
             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 
         attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
@@ -238,17 +232,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)
 
         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
 
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@@ -258,17 +249,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_output),
             [bsz * self.num_heads, tgt_len, self.head_dim],
             message=(
                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {shape_list(attn_output)}"
             ),
         )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -664,10 +652,8 @@ class TFOPTDecoder(tf.keras.layers.Layer):
         present_key_values = () if use_cache else None
 
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
         for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
-            if attn_mask is not None and tf.executing_eagerly():
+            if attn_mask is not None:
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
...
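The encoder/decoder hunks above all reduce to the same head-mask sanity check: the leading dimension of the mask must match the number of layers. A hedged, standalone sketch of what that check does when the shapes disagree; the `num_layers` and `head_mask` values are invented, and the message only mimics the style used in the modeling files, it is not the exact in-tree text.

import tensorflow as tf

num_layers = 12
head_mask = tf.ones((6, 16))  # deliberately wrong: 6 != 12

try:
    tf.debugging.assert_equal(
        tf.shape(head_mask)[0],
        num_layers,
        message=f"The head_mask should be specified for {num_layers} layers, but it is for {head_mask.shape[0]}.",
    )
except tf.errors.InvalidArgumentError as err:
    print(err.message)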
src/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
         shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
     )
 
-    if tf.executing_eagerly():
     # "Verify that `labels` has only positive values and -100"
     assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
 
     # Make sure the assertion op is called by wrapping the result in an identity no-op
     with tf.control_dependencies([assert_gte0]):
         shifted_input_ids = tf.identity(shifted_input_ids)
 
     return shifted_input_ids
@@ -265,31 +264,25 @@ class TFPegasusAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_weights),
             [bsz * self.num_heads, tgt_len, src_len],
             message=(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {shape_list(attn_weights)}"
             ),
         )
 
         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attention_mask),
                 [bsz, 1, tgt_len, src_len],
                 message=(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {shape_list(attention_mask)}"
                 ),
             )
 
             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 
         attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
@@ -297,17 +290,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)
 
         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
 
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@@ -317,17 +307,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_output),
             [bsz * self.num_heads, tgt_len, self.head_dim],
             message=(
                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {shape_list(attn_output)}"
             ),
         )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -377,14 +364,11 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
             hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
         )
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(hidden_states),
             shape_list(residual),
             message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
         )
 
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states
@@ -804,9 +788,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
         all_attentions = () if output_attentions else None
 
         # check if head_mask has a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if head_mask is not None and tf.executing_eagerly():
+        if head_mask is not None:
             tf.debugging.assert_equal(
                 shape_list(head_mask)[0],
                 len(self.layers),
@@ -1015,10 +997,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
         present_key_values = () if use_cache else None
 
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-            if attn_mask is not None and tf.executing_eagerly():
+            if attn_mask is not None:
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
...
src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
@@ -74,13 +74,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
         shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
     )
 
-    if tf.executing_eagerly():
     # "Verify that `labels` has only positive values and -100"
     assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
 
     # Make sure the assertion op is called by wrapping the result in an identity no-op
     with tf.control_dependencies([assert_gte0]):
         shifted_input_ids = tf.identity(shifted_input_ids)
 
     return shifted_input_ids
@@ -324,31 +323,25 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_weights),
             [bsz * self.num_heads, tgt_len, src_len],
             message=(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {shape_list(attn_weights)}"
             ),
         )
 
         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attention_mask),
                 [bsz, 1, tgt_len, src_len],
                 message=(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                     f" {shape_list(attention_mask)}"
                 ),
             )
 
             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 
         attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
@@ -356,17 +349,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)
 
         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(layer_head_mask),
                 [self.num_heads],
                 message=(
                     f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                     f" {shape_list(layer_head_mask)}"
                 ),
             )
 
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@@ -376,17 +366,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(attn_output),
             [bsz * self.num_heads, tgt_len, self.head_dim],
             message=(
                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {shape_list(attn_output)}"
             ),
         )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -434,14 +421,11 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
             training=training,
         )
 
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
         tf.debugging.assert_equal(
             shape_list(hidden_states),
             shape_list(residual),
             message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
         )
 
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states
@@ -866,8 +850,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
         all_attentions = () if output_attentions else None
 
         # check if head_mask has a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
-        if head_mask is not None and tf.executing_eagerly():
+        if head_mask is not None:
             tf.debugging.assert_equal(
                 shape_list(head_mask)[0],
                 len(self.layers),
@@ -1068,9 +1051,8 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
         next_decoder_cache = () if use_cache else None
 
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
         for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-            if attn_mask is not None and tf.executing_eagerly():
+            if attn_mask is not None:
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
...
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
View file @
31be02f1
...
@@ -161,13 +161,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
         shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
     )

-    if tf.executing_eagerly():
-        # "Verify that `labels` has only positive values and -100"
-        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
+    # "Verify that `labels` has only positive values and -100"
+    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

-        # Make sure the assertion op is called by wrapping the result in an identity no-op
-        with tf.control_dependencies([assert_gte0]):
-            shifted_input_ids = tf.identity(shifted_input_ids)
+    # Make sure the assertion op is called by wrapping the result in an identity no-op
+    with tf.control_dependencies([assert_gte0]):
+        shifted_input_ids = tf.identity(shifted_input_ids)

     return shifted_input_ids
...
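(The control_dependencies/identity wrapper kept in the new code exists so the assertion stays wired into the graph even though its result is unused. A rough standalone sketch of the same -100 replacement and check, with made-up toy labels and pad_token_id; it is not the library function itself.)

import tensorflow as tf

def replace_ignored_labels(labels, pad_token_id):
    # Replace the -100 "ignore" positions with the pad token, as the diff above does.
    filled = tf.where(labels == -100, tf.fill(tf.shape(labels), pad_token_id), labels)
    # Verify the result has no negative ids; tying the assert to the returned
    # tensor via control_dependencies keeps the check in the graph.
    assert_gte0 = tf.debugging.assert_greater_equal(filled, tf.constant(0, dtype=labels.dtype))
    with tf.control_dependencies([assert_gte0]):
        filled = tf.identity(filled)
    return filled

labels = tf.constant([[5, 7, -100, -100]])
print(replace_ignored_labels(labels, pad_token_id=1).numpy())  # [[5 7 1 1]]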
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
View file @
31be02f1
...
@@ -852,31 +852,25 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
-            tf.debugging.assert_equal(
-                shape_list(attn_weights),
-                [bsz * self.num_heads, tgt_len, src_len],
-                message=(
-                    f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
-                    f" {shape_list(attn_weights)}"
-                ),
-            )
+        tf.debugging.assert_equal(
+            shape_list(attn_weights),
+            [bsz * self.num_heads, tgt_len, src_len],
+            message=(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {shape_list(attn_weights)}"
+            ),
+        )

         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
-                tf.debugging.assert_equal(
-                    shape_list(attention_mask),
-                    [bsz, 1, tgt_len, src_len],
-                    message=(
-                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
-                        f" {shape_list(attention_mask)}"
-                    ),
-                )
+            tf.debugging.assert_equal(
+                shape_list(attention_mask),
+                [bsz, 1, tgt_len, src_len],
+                message=(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+                    f" {shape_list(attention_mask)}"
+                ),
+            )

             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
...
@@ -884,17 +878,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)

         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
-                tf.debugging.assert_equal(
-                    shape_list(layer_head_mask),
-                    [self.num_heads],
-                    message=(
-                        f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
-                        f" {shape_list(layer_head_mask)}"
-                    ),
-                )
+            tf.debugging.assert_equal(
+                shape_list(layer_head_mask),
+                [self.num_heads],
+                message=(
+                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+                    f" {shape_list(layer_head_mask)}"
+                ),
+            )

             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
...
@@ -904,17 +895,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)

-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
-            tf.debugging.assert_equal(
-                shape_list(attn_output),
-                [bsz * self.num_heads, tgt_len, self.head_dim],
-                message=(
-                    f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
-                    f" {shape_list(attn_output)}"
-                ),
-            )
+        tf.debugging.assert_equal(
+            shape_list(attn_output),
+            [bsz * self.num_heads, tgt_len, self.head_dim],
+            message=(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {shape_list(attn_output)}"
+            ),
+        )

         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
...
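(All of these assertions compare against shape_list(...), the transformers helper that returns static Python ints where a dimension is known at trace time and falls back to dynamic tf.shape values otherwise. A hedged re-implementation sketch of that idea, not the library's exact code.)

import tensorflow as tf

def shape_list_sketch(t):
    # Prefer statically known dimensions (plain Python ints) and use the
    # dynamic tf.shape() entry for any dimension that is None at trace time.
    static = t.shape.as_list()
    dynamic = tf.shape(t)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

x = tf.zeros((2, 3, 4))
print(shape_list_sketch(x))  # [2, 3, 4] -- fully static in eager mode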
src/transformers/models/xglm/modeling_tf_xglm.py
View file @
31be02f1
...
@@ -239,31 +239,25 @@ class TFXGLMAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
-            tf.debugging.assert_equal(
-                shape_list(attn_weights),
-                [bsz * self.num_heads, tgt_len, src_len],
-                message=(
-                    f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
-                    f" {shape_list(attn_weights)}"
-                ),
-            )
+        tf.debugging.assert_equal(
+            shape_list(attn_weights),
+            [bsz * self.num_heads, tgt_len, src_len],
+            message=(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {shape_list(attn_weights)}"
+            ),
+        )

         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
-                tf.debugging.assert_equal(
-                    shape_list(attention_mask),
-                    [bsz, 1, tgt_len, src_len],
-                    message=(
-                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
-                        f" {shape_list(attention_mask)}"
-                    ),
-                )
+            tf.debugging.assert_equal(
+                shape_list(attention_mask),
+                [bsz, 1, tgt_len, src_len],
+                message=(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+                    f" {shape_list(attention_mask)}"
+                ),
+            )

             attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
...
@@ -271,17 +265,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)

         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
-                tf.debugging.assert_equal(
-                    shape_list(layer_head_mask),
-                    [self.num_heads],
-                    message=(
-                        f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
-                        f" {shape_list(layer_head_mask)}"
-                    ),
-                )
+            tf.debugging.assert_equal(
+                shape_list(layer_head_mask),
+                [self.num_heads],
+                message=(
+                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+                    f" {shape_list(layer_head_mask)}"
+                ),
+            )

             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
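(For reference, the layer_head_mask branch kept above scales the attention weights per head by broadcasting a (1, num_heads, 1, 1) mask; a toy standalone sketch with arbitrary sizes.)

import tensorflow as tf

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = tf.random.uniform((bsz * num_heads, tgt_len, src_len))
layer_head_mask = tf.constant([1.0, 0.0, 1.0, 1.0])  # zero out the second head

# (1, num_heads, 1, 1) * (bsz, num_heads, tgt_len, src_len) broadcasts per head,
# then the result is flattened back to (bsz * num_heads, tgt_len, src_len).
masked = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
    attn_weights, (bsz, num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(masked, (bsz * num_heads, tgt_len, src_len))
print(attn_weights.shape)  # (8, 5, 5)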
...
@@ -291,17 +282,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
         attn_probs = self.dropout(attn_weights, training=training)
         attn_output = tf.matmul(attn_probs, value_states)

-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
-            tf.debugging.assert_equal(
-                shape_list(attn_output),
-                [bsz * self.num_heads, tgt_len, self.head_dim],
-                message=(
-                    f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
-                    f" {shape_list(attn_output)}"
-                ),
-            )
+        tf.debugging.assert_equal(
+            shape_list(attn_output),
+            [bsz * self.num_heads, tgt_len, self.head_dim],
+            message=(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {shape_list(attn_output)}"
+            ),
+        )

         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
...
@@ -568,10 +556,8 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
         next_decoder_cache = () if use_cache else None

         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
         for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-            if attn_mask is not None and tf.executing_eagerly():
+            if attn_mask is not None:
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
...
src/transformers/models/xlm/modeling_tf_xlm.py
View file @
31be02f1
...
@@ -105,9 +105,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
     # sanity check
     # assert shape_list(mask) == [bs, slen]
-    if tf.executing_eagerly():
-        tf.debugging.assert_equal(shape_list(mask), [bs, slen])
-        assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
+    tf.debugging.assert_equal(shape_list(mask), [bs, slen])
+    if causal:
+        tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])

     return mask, attn_mask
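(The get_masks hunk above also swaps a plain Python assert for tf.debugging.assert_equal: the Python assert only runs while the function is being traced, whereas the graph-level assert is evaluated against the real shapes on every call. A small hedged illustration with made-up shapes.)

import tensorflow as tf

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None], dtype=tf.float32),        # mask: (bs, slen)
    tf.TensorSpec(shape=[None, None, None], dtype=tf.float32),  # attn_mask: (bs, slen, slen)
])
def check_masks(mask, attn_mask):
    bs, slen = tf.shape(mask)[0], tf.shape(mask)[1]
    # The graph-level assert checks the runtime shapes on every call and
    # raises an error on a mismatch, even though bs and slen are symbolic here.
    tf.debugging.assert_equal(tf.shape(attn_mask), [bs, slen, slen])
    return mask, attn_mask

check_masks(tf.ones((2, 5)), tf.ones((2, 5, 5)))    # passes
# check_masks(tf.ones((2, 5)), tf.ones((2, 5, 4)))  # fails at call time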
...
@@ -384,10 +384,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         # check inputs
         # assert shape_list(lengths)[0] == bs
-        if tf.executing_eagerly():
-            tf.debugging.assert_equal(
-                shape_list(lengths)[0], bs
-            ), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
+        tf.debugging.assert_equal(
+            shape_list(lengths)[0], bs
+        ), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
         # assert lengths.max().item() <= slen
         # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
         # assert (src_enc is None) == (src_len is None)
...
@@ -405,15 +404,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             position_ids = tf.expand_dims(tf.range(slen), axis=0)
             position_ids = tf.tile(position_ids, (bs, 1))

-        if tf.executing_eagerly():
-            # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
-            tf.debugging.assert_equal(
-                shape_list(position_ids), [bs, slen]
-            ), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
+        # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
+        tf.debugging.assert_equal(
+            shape_list(position_ids), [bs, slen]
+        ), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
         # position_ids = position_ids.transpose(0, 1)

         # langs
-        if langs is not None and tf.executing_eagerly():
+        if langs is not None:
             # assert shape_list(langs) == [bs, slen]  # (slen, bs)
             tf.debugging.assert_equal(
                 shape_list(langs), [bs, slen]
...
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
View file @
31be02f1
...
@@ -1693,13 +1693,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
         shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
     )

-    if tf.executing_eagerly():
-        # "Verify that `labels` has only positive values and -100"
-        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
+    # "Verify that `labels` has only positive values and -100"
+    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))

-        # Make sure the assertion op is called by wrapping the result in an identity no-op
-        with tf.control_dependencies([assert_gte0]):
-            shifted_input_ids = tf.identity(shifted_input_ids)
+    # Make sure the assertion op is called by wrapping the result in an identity no-op
+    with tf.control_dependencies([assert_gte0]):
+        shifted_input_ids = tf.identity(shifted_input_ids)

     return shifted_input_ids
...
@@ -1837,24 +1836,18 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
-            tf.debugging.assert_equal(
-                shape_list(attn_weights),
-                [bsz * self.num_heads, tgt_len, src_len],
-                message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
-            )
+        tf.debugging.assert_equal(
+            shape_list(attn_weights),
+            [bsz * self.num_heads, tgt_len, src_len],
+            message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+        )

         if attention_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
-                tf.debugging.assert_equal(
-                    shape_list(attention_mask),
-                    [bsz, 1, tgt_len, src_len],
-                    message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
-                )
+            tf.debugging.assert_equal(
+                shape_list(attention_mask),
+                [bsz, 1, tgt_len, src_len],
+                message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+            )

             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
...
@@ -1862,14 +1855,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
         attn_weights = stable_softmax(attn_weights, axis=-1)

         if layer_head_mask is not None:
-            # The tf.debugging asserts are not compliant with XLA then they
-            # have to be disabled in other modes than eager.
-            if tf.executing_eagerly():
-                tf.debugging.assert_equal(
-                    shape_list(layer_head_mask),
-                    [self.num_heads],
-                    message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
-                )
+            tf.debugging.assert_equal(
+                shape_list(layer_head_mask),
+                [self.num_heads],
+                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+            )

             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
...
@@ -1880,14 +1870,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
         attn_output = tf.matmul(attn_probs, value_states)

-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
-            tf.debugging.assert_equal(
-                shape_list(attn_output),
-                [bsz * self.num_heads, tgt_len, self.head_dim],
-                message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
-            )
+        tf.debugging.assert_equal(
+            shape_list(attn_output),
+            [bsz * self.num_heads, tgt_len, self.head_dim],
+            message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+        )

         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
...
@@ -1929,14 +1916,11 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
             hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
         )

-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if tf.executing_eagerly():
-            tf.debugging.assert_equal(
-                shape_list(hidden_states),
-                shape_list(residual),
-                message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
-            )
+        tf.debugging.assert_equal(
+            shape_list(hidden_states),
+            shape_list(residual),
+            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+        )

         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states
...
@@ -2332,9 +2316,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
         all_attentions = () if output_attentions else None

         # check if head_mask has a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
-        if head_mask is not None and tf.executing_eagerly():
+        if head_mask is not None:
             tf.debugging.assert_equal(
                 shape_list(head_mask)[0],
                 len(self.layers),
...
@@ -2529,10 +2511,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
         present_key_values = () if use_cache else None

         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-        # The tf.debugging asserts are not compliant with XLA then they
-        # have to be disabled in other modes than eager.
        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-            if attn_mask is not None and tf.executing_eagerly():
+            if attn_mask is not None:
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
...
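(The encoder and decoder hunks all validate that a supplied head mask has one entry per layer; a toy standalone version of that check, with a made-up layer count and masks.)

import tensorflow as tf

num_layers = 6
head_mask = tf.ones((6, 8))   # (num_layers, num_heads)
cross_attn_head_mask = None   # optional, may be absent

for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
    if attn_mask is not None:
        tf.debugging.assert_equal(
            tf.shape(attn_mask)[0],
            num_layers,
            message=f"The {attn_mask_name} should be specified for {num_layers} layers.",
        )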