Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
31be02f1
Unverified
Commit
31be02f1
authored
Sep 14, 2022
by
Joao Gante
Committed by
GitHub
Sep 14, 2022
Browse files
TF: tf.debugging assertions without tf.running_eagerly() protection (#19030)
parent
693ba2cc
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
696 additions
and
947 deletions
+696
-947
src/transformers/models/bart/modeling_tf_bart.py
src/transformers/models/bart/modeling_tf_bart.py
+42
-62
src/transformers/models/blenderbot/modeling_tf_blenderbot.py
src/transformers/models/blenderbot/modeling_tf_blenderbot.py
+42
-62
src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
...s/models/blenderbot_small/modeling_tf_blenderbot_small.py
+42
-62
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
...ers/models/encoder_decoder/modeling_tf_encoder_decoder.py
+5
-6
src/transformers/models/flaubert/modeling_tf_flaubert.py
src/transformers/models/flaubert/modeling_tf_flaubert.py
+12
-14
src/transformers/models/hubert/modeling_tf_hubert.py
src/transformers/models/hubert/modeling_tf_hubert.py
+30
-42
src/transformers/models/led/modeling_tf_led.py
src/transformers/models/led/modeling_tf_led.py
+130
-151
src/transformers/models/longformer/modeling_tf_longformer.py
src/transformers/models/longformer/modeling_tf_longformer.py
+89
-104
src/transformers/models/marian/modeling_tf_marian.py
src/transformers/models/marian/modeling_tf_marian.py
+42
-62
src/transformers/models/mbart/modeling_tf_mbart.py
src/transformers/models/mbart/modeling_tf_mbart.py
+37
-56
src/transformers/models/opt/modeling_tf_opt.py
src/transformers/models/opt/modeling_tf_opt.py
+31
-45
src/transformers/models/pegasus/modeling_tf_pegasus.py
src/transformers/models/pegasus/modeling_tf_pegasus.py
+42
-62
src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
...rmers/models/speech_to_text/modeling_tf_speech_to_text.py
+42
-60
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
...ion_encoder_decoder/modeling_tf_vision_encoder_decoder.py
+5
-6
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
+30
-42
src/transformers/models/xglm/modeling_tf_xglm.py
src/transformers/models/xglm/modeling_tf_xglm.py
+31
-45
src/transformers/models/xlm/modeling_tf_xlm.py
src/transformers/models/xlm/modeling_tf_xlm.py
+12
-14
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
...ame}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
+32
-52
No files found.
src/transformers/models/bart/modeling_tf_bart.py
View file @
31be02f1
...
@@ -71,13 +71,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -71,13 +71,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
)
)
if
tf
.
executing_eagerly
():
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
@@ -229,31 +228,25 @@ class TFBartAttention(tf.keras.layers.Layer):
...
@@ -229,31 +228,25 @@ class TFBartAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -261,17 +254,14 @@ class TFBartAttention(tf.keras.layers.Layer):
...
@@ -261,17 +254,14 @@ class TFBartAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -281,17 +271,14 @@ class TFBartAttention(tf.keras.layers.Layer):
...
@@ -281,17 +271,14 @@ class TFBartAttention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -339,14 +326,11 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
...
@@ -339,14 +326,11 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
)
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
hidden_states
),
if
tf
.
executing_eagerly
():
shape_list
(
residual
),
tf
.
debugging
.
assert_equal
(
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
shape_list
(
hidden_states
),
)
shape_list
(
residual
),
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -776,9 +760,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
...
@@ -776,9 +760,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
all_attentions
=
()
if
output_attentions
else
None
all_attentions
=
()
if
output_attentions
else
None
# check if head_mask has a correct number of layers specified if desired
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
if
head_mask
is
not
None
:
# have to be disabled in other modes than eager.
if
head_mask
is
not
None
and
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
head_mask
)[
0
],
shape_list
(
head_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
@@ -983,10 +965,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
...
@@ -983,10 +965,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
present_key_values
=
()
if
use_cache
else
None
present_key_values
=
()
if
use_cache
else
None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
if
attn_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
attn_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)[
0
],
shape_list
(
attn_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
src/transformers/models/blenderbot/modeling_tf_blenderbot.py
View file @
31be02f1
...
@@ -73,13 +73,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -73,13 +73,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
)
)
if
tf
.
executing_eagerly
():
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
@@ -225,31 +224,25 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
...
@@ -225,31 +224,25 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -257,17 +250,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
...
@@ -257,17 +250,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -277,17 +267,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
...
@@ -277,17 +267,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -337,14 +324,11 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
...
@@ -337,14 +324,11 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
)
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
hidden_states
),
if
tf
.
executing_eagerly
():
shape_list
(
residual
),
tf
.
debugging
.
assert_equal
(
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
shape_list
(
hidden_states
),
)
shape_list
(
residual
),
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -755,9 +739,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
...
@@ -755,9 +739,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
all_attentions
=
()
if
output_attentions
else
None
all_attentions
=
()
if
output_attentions
else
None
# check if head_mask has a correct number of layers specified if desired
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
if
head_mask
is
not
None
:
# have to be disabled in other modes than eager.
if
head_mask
is
not
None
and
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
head_mask
)[
0
],
shape_list
(
head_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
@@ -966,10 +948,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
...
@@ -966,10 +948,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
present_key_values
=
()
if
use_cache
else
None
present_key_values
=
()
if
use_cache
else
None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
if
attn_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
attn_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)[
0
],
shape_list
(
attn_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
View file @
31be02f1
...
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
)
)
if
tf
.
executing_eagerly
():
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
@@ -225,31 +224,25 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
...
@@ -225,31 +224,25 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -257,17 +250,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
...
@@ -257,17 +250,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -277,17 +267,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
...
@@ -277,17 +267,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -336,14 +323,11 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
...
@@ -336,14 +323,11 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
)
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
hidden_states
),
if
tf
.
executing_eagerly
():
shape_list
(
residual
),
tf
.
debugging
.
assert_equal
(
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
shape_list
(
hidden_states
),
)
shape_list
(
residual
),
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -761,9 +745,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
...
@@ -761,9 +745,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
all_attentions
=
()
if
output_attentions
else
None
all_attentions
=
()
if
output_attentions
else
None
# check if head_mask has a correct number of layers specified if desired
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
if
head_mask
is
not
None
:
# have to be disabled in other modes than eager.
if
head_mask
is
not
None
and
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
head_mask
)[
0
],
shape_list
(
head_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
@@ -968,10 +950,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
...
@@ -968,10 +950,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
present_key_values
=
()
if
use_cache
else
None
present_key_values
=
()
if
use_cache
else
None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
if
attn_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
attn_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)[
0
],
shape_list
(
attn_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
View file @
31be02f1
...
@@ -171,13 +171,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -171,13 +171,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
)
)
if
tf
.
executing_eagerly
():
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
...
src/transformers/models/flaubert/modeling_tf_flaubert.py
View file @
31be02f1
...
@@ -200,9 +200,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
...
@@ -200,9 +200,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
# sanity check
# sanity check
# assert shape_list(mask) == [bs, slen]
# assert shape_list(mask) == [bs, slen]
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
mask
),
[
bs
,
slen
])
tf
.
debugging
.
assert_equal
(
shape_list
(
mask
),
[
bs
,
slen
])
if
causal
:
assert
causal
is
False
or
shape_list
(
attn_mask
)
==
[
bs
,
slen
,
slen
]
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)
,
[
bs
,
slen
,
slen
]
)
return
mask
,
attn_mask
return
mask
,
attn_mask
...
@@ -517,10 +517,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
...
@@ -517,10 +517,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# check inputs
# check inputs
# assert shape_list(lengths)[0] == bs
# assert shape_list(lengths)[0] == bs
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
lengths
)[
0
],
bs
shape_list
(
lengths
)[
0
],
bs
),
f
"Expected batch size
{
shape_list
(
lengths
)[
0
]
}
and received batch size
{
bs
}
mismatched"
),
f
"Expected batch size
{
shape_list
(
lengths
)[
0
]
}
and received batch size
{
bs
}
mismatched"
# assert lengths.max().item() <= slen
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
# assert (src_enc is None) == (src_len is None)
...
@@ -538,15 +537,14 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
...
@@ -538,15 +537,14 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
position_ids
=
tf
.
expand_dims
(
tf
.
range
(
slen
),
axis
=
0
)
position_ids
=
tf
.
expand_dims
(
tf
.
range
(
slen
),
axis
=
0
)
position_ids
=
tf
.
tile
(
position_ids
,
(
bs
,
1
))
position_ids
=
tf
.
tile
(
position_ids
,
(
bs
,
1
))
if
tf
.
executing_eagerly
():
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
position_ids
),
[
bs
,
slen
]
shape_list
(
position_ids
),
[
bs
,
slen
]
),
f
"Position id shape
{
shape_list
(
position_ids
)
}
and input shape
{
[
bs
,
slen
]
}
mismatched"
),
f
"Position id shape
{
shape_list
(
position_ids
)
}
and input shape
{
[
bs
,
slen
]
}
mismatched"
# position_ids = position_ids.transpose(0, 1)
# position_ids = position_ids.transpose(0, 1)
# langs
# langs
if
langs
is
not
None
and
tf
.
executing_eagerly
()
:
if
langs
is
not
None
:
# assert shape_list(langs) == [bs, slen] # (slen, bs)
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
langs
),
[
bs
,
slen
]
shape_list
(
langs
),
[
bs
,
slen
]
...
...
src/transformers/models/hubert/modeling_tf_hubert.py
View file @
31be02f1
...
@@ -816,31 +816,25 @@ class TFHubertAttention(tf.keras.layers.Layer):
...
@@ -816,31 +816,25 @@ class TFHubertAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -848,17 +842,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
...
@@ -848,17 +842,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -868,17 +859,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
...
@@ -868,17 +859,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
...
src/transformers/models/led/modeling_tf_led.py
View file @
31be02f1
...
@@ -64,12 +64,11 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -64,12 +64,11 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
)
)
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
if
tf
.
executing_eagerly
():
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
@@ -213,12 +212,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -213,12 +212,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
value_vectors
=
self
.
value
(
hidden_states
)
value_vectors
=
self
.
value
(
hidden_states
)
batch_size
,
seq_len
,
embed_dim
=
shape_list
(
hidden_states
)
batch_size
,
seq_len
,
embed_dim
=
shape_list
(
hidden_states
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
embed_dim
,
embed_dim
,
self
.
embed_dim
,
self
.
embed_dim
,
message
=
f
"hidden_states should have embed_dim =
{
self
.
embed_dim
}
, but has
{
embed_dim
}
"
,
message
=
f
"hidden_states should have embed_dim =
{
self
.
embed_dim
}
, but has
{
embed_dim
}
"
,
)
)
# normalize query
# normalize query
query_vectors
/=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
head_dim
,
dtype
=
query_vectors
.
dtype
))
query_vectors
/=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
head_dim
,
dtype
=
query_vectors
.
dtype
))
...
@@ -245,15 +243,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -245,15 +243,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# pad local attention probs
# pad local attention probs
attn_scores
+=
diagonal_mask
attn_scores
+=
diagonal_mask
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_scores
),
shape_list
(
attn_scores
),
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
one_sided_attn_window_size
*
2
+
1
],
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
one_sided_attn_window_size
*
2
+
1
],
message
=
(
message
=
(
f
"attn_probs should be of size (
{
batch_size
}
,
{
seq_len
}
,
{
self
.
num_heads
}
,"
f
"attn_probs should be of size (
{
batch_size
}
,
{
seq_len
}
,
{
self
.
num_heads
}
,"
f
"
{
self
.
one_sided_attn_window_size
*
2
+
1
}
), but is of size
{
shape_list
(
attn_scores
)
}
"
f
"
{
self
.
one_sided_attn_window_size
*
2
+
1
}
), but is of size
{
shape_list
(
attn_scores
)
}
"
),
),
)
)
# compute global attn indices required through out forward fn
# compute global attn indices required through out forward fn
(
(
...
@@ -301,15 +298,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -301,15 +298,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
)
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
layer_head_mask
),
shape_list
(
layer_head_mask
),
[
self
.
num_heads
],
[
self
.
num_heads
],
message
=
(
message
=
(
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"
{
shape_list
(
layer_head_mask
)
}
"
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
),
)
)
attn_probs
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
1
,
-
1
,
1
))
*
attn_probs
attn_probs
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
1
,
-
1
,
1
))
*
attn_probs
...
@@ -332,12 +328,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -332,12 +328,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
),
),
)
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_output
),
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
],
message
=
"Unexpected size"
shape_list
(
attn_output
),
)
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
],
message
=
"Unexpected size"
,
)
attn_output
=
tf
.
reshape
(
attn_output
,
(
batch_size
,
seq_len
,
embed_dim
))
attn_output
=
tf
.
reshape
(
attn_output
,
(
batch_size
,
seq_len
,
embed_dim
))
...
@@ -392,20 +385,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -392,20 +385,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
"""
"""
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
query
)
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
query
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
seq_len
%
(
window_overlap
*
2
),
seq_len
%
(
window_overlap
*
2
),
0
,
0
,
message
=
f
"Sequence length should be multiple of
{
window_overlap
*
2
}
. Given
{
seq_len
}
"
,
message
=
f
"Sequence length should be multiple of
{
window_overlap
*
2
}
. Given
{
seq_len
}
"
,
)
)
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
query
),
shape_list
(
query
),
shape_list
(
key
),
shape_list
(
key
),
message
=
(
message
=
(
f
"Shape of query and key should be equal, but got query:
{
shape_list
(
query
)
}
and key:"
f
"Shape of query and key should be equal, but got query:
{
shape_list
(
query
)
}
and key:"
f
"
{
shape_list
(
key
)
}
"
f
"
{
shape_list
(
key
)
}
"
),
),
)
)
chunks_count
=
seq_len
//
window_overlap
-
1
chunks_count
=
seq_len
//
window_overlap
-
1
...
@@ -539,22 +531,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -539,22 +531,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
value
)
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
value
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
seq_len
%
(
window_overlap
*
2
),
0
,
message
=
"Seq_len has to be multiple of 2 * window_overlap"
seq_len
%
(
window_overlap
*
2
),
)
0
,
tf
.
debugging
.
assert_equal
(
message
=
"Seq_len has to be multiple of 2 * window_overlap"
,
shape_list
(
attn_probs
)[:
3
],
)
shape_list
(
value
)[:
3
],
tf
.
debugging
.
assert_equal
(
message
=
"value and attn_probs must have same dims (except head_dim)"
,
shape_list
(
attn_probs
)[:
3
],
)
shape_list
(
value
)[:
3
],
tf
.
debugging
.
assert_equal
(
message
=
"value and attn_probs must have same dims (except head_dim)"
,
shape_list
(
attn_probs
)[
3
],
)
2
*
window_overlap
+
1
,
tf
.
debugging
.
assert_equal
(
message
=
"attn_probs last dim has to be 2 * window_overlap + 1"
,
shape_list
(
attn_probs
)[
3
],
)
2
*
window_overlap
+
1
,
message
=
"attn_probs last dim has to be 2 * window_overlap + 1"
,
)
chunks_count
=
seq_len
//
window_overlap
-
1
chunks_count
=
seq_len
//
window_overlap
-
1
...
@@ -592,12 +581,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -592,12 +581,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
(
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
),
(
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
),
)
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
chunked_value
),
shape_list
(
chunked_value
),
[
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
],
[
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
],
message
=
"Chunked value has the wrong shape"
,
message
=
"Chunked value has the wrong shape"
,
)
)
chunked_attn_probs
=
self
.
_pad_and_diagonalize
(
chunked_attn_probs
)
chunked_attn_probs
=
self
.
_pad_and_diagonalize
(
chunked_attn_probs
)
context
=
tf
.
einsum
(
"bcwd,bcdh->bcwh"
,
chunked_attn_probs
,
chunked_value
)
context
=
tf
.
einsum
(
"bcwd,bcdh->bcwh"
,
chunked_attn_probs
,
chunked_value
)
...
@@ -685,15 +673,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -685,15 +673,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# chunk with overlap
# chunk with overlap
chunked_hidden_states
=
tf
.
signal
.
frame
(
hidden_states
,
frame_size
,
frame_hop_size
)
chunked_hidden_states
=
tf
.
signal
.
frame
(
hidden_states
,
frame_size
,
frame_hop_size
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
chunked_hidden_states
),
shape_list
(
chunked_hidden_states
),
[
batch_size
,
num_output_chunks
,
frame_size
],
[
batch_size
,
num_output_chunks
,
frame_size
],
message
=
(
message
=
(
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
f
"
{
[
batch_size
,
frame_size
,
num_output_chunks
]
}
, but got
{
shape_list
(
chunked_hidden_states
)
}
."
f
"
{
[
batch_size
,
frame_size
,
num_output_chunks
]
}
, but got
{
shape_list
(
chunked_hidden_states
)
}
."
),
),
)
)
chunked_hidden_states
=
tf
.
reshape
(
chunked_hidden_states
=
tf
.
reshape
(
chunked_hidden_states
,
chunked_hidden_states
,
...
@@ -866,16 +853,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -866,16 +853,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# compute attn scores
# compute attn scores
global_attn_scores
=
tf
.
matmul
(
global_query_vectors_only_global
,
global_key_vectors
,
transpose_b
=
True
)
global_attn_scores
=
tf
.
matmul
(
global_query_vectors_only_global
,
global_key_vectors
,
transpose_b
=
True
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
global_attn_scores
),
shape_list
(
global_attn_scores
),
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
],
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
],
message
=
(
message
=
(
"global_attn_scores have the wrong size. Size should be"
"global_attn_scores have the wrong size. Size should be"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
}
, but is"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
}
, but is"
f
"
{
shape_list
(
global_attn_scores
)
}
."
f
"
{
shape_list
(
global_attn_scores
)
}
."
),
),
)
)
global_attn_scores
=
tf
.
reshape
(
global_attn_scores
=
tf
.
reshape
(
global_attn_scores
,
global_attn_scores
,
...
@@ -909,15 +895,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -909,15 +895,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# apply layer head masking
# apply layer head masking
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
layer_head_mask
),
shape_list
(
layer_head_mask
),
[
self
.
num_heads
],
[
self
.
num_heads
],
message
=
(
message
=
(
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"
{
shape_list
(
layer_head_mask
)
}
"
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
),
)
)
global_attn_probs_float
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
global_attn_probs_float
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
global_attn_probs_float
,
(
batch_size
,
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
global_attn_probs_float
,
(
batch_size
,
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
)
)
...
@@ -931,16 +916,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
...
@@ -931,16 +916,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# global attn output
# global attn output
global_attn_output
=
tf
.
matmul
(
global_attn_probs
,
global_value_vectors
)
global_attn_output
=
tf
.
matmul
(
global_attn_probs
,
global_value_vectors
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
global_attn_output
),
shape_list
(
global_attn_output
),
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
],
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
],
message
=
(
message
=
(
"global_attn_output tensor has the wrong size. Size should be"
"global_attn_output tensor has the wrong size. Size should be"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
)
}
, but is"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
)
}
, but is"
f
"
{
shape_list
(
global_attn_output
)
}
."
f
"
{
shape_list
(
global_attn_output
)
}
."
),
),
)
)
global_attn_output
=
tf
.
reshape
(
global_attn_output
=
tf
.
reshape
(
global_attn_output
,
global_attn_output
,
...
@@ -1091,27 +1075,25 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
...
@@ -1091,27 +1075,25 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_weights
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
tf
.
cast
(
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
attention_mask
,
dtype
=
attn_weights
.
dtype
)
)
...
@@ -1120,15 +1102,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
...
@@ -1120,15 +1102,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
layer_head_mask
),
shape_list
(
layer_head_mask
),
[
self
.
num_heads
],
[
self
.
num_heads
],
message
=
(
message
=
(
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"
{
shape_list
(
layer_head_mask
)
}
"
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
),
)
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -1139,15 +1120,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
...
@@ -1139,15 +1120,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_output
),
shape_list
(
attn_output
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
message
=
(
message
=
(
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
f
"
{
shape_list
(
attn_output
)
}
"
f
"
{
shape_list
(
attn_output
)
}
"
),
),
)
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -1199,12 +1179,11 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
...
@@ -1199,12 +1179,11 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
hidden_states
=
layer_outputs
[
0
]
hidden_states
=
layer_outputs
[
0
]
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
hidden_states
),
shape_list
(
hidden_states
),
shape_list
(
residual
),
shape_list
(
residual
),
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
)
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -1792,7 +1771,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
...
@@ -1792,7 +1771,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
all_attentions
=
all_global_attentions
=
()
if
output_attentions
else
None
all_attentions
=
all_global_attentions
=
()
if
output_attentions
else
None
# check if head_mask has a correct number of layers specified if desired
# check if head_mask has a correct number of layers specified if desired
if
head_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
head_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
head_mask
)[
0
],
shape_list
(
head_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
@@ -2055,7 +2034,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
...
@@ -2055,7 +2034,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
present_key_values
=
()
present_key_values
=
()
# check if head_mask has a correct number of layers specified if desired
# check if head_mask has a correct number of layers specified if desired
if
head_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
head_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
head_mask
)[
0
],
shape_list
(
head_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
src/transformers/models/longformer/modeling_tf_longformer.py
View file @
31be02f1
...
@@ -738,12 +738,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -738,12 +738,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
value_vectors
=
self
.
value
(
hidden_states
)
value_vectors
=
self
.
value
(
hidden_states
)
batch_size
,
seq_len
,
embed_dim
=
shape_list
(
hidden_states
)
batch_size
,
seq_len
,
embed_dim
=
shape_list
(
hidden_states
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
embed_dim
,
embed_dim
,
self
.
embed_dim
,
self
.
embed_dim
,
message
=
f
"hidden_states should have embed_dim =
{
self
.
embed_dim
}
, but has
{
embed_dim
}
"
,
message
=
f
"hidden_states should have embed_dim =
{
self
.
embed_dim
}
, but has
{
embed_dim
}
"
,
)
)
# normalize query
# normalize query
query_vectors
/=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
head_dim
,
dtype
=
query_vectors
.
dtype
))
query_vectors
/=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
head_dim
,
dtype
=
query_vectors
.
dtype
))
...
@@ -770,15 +769,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -770,15 +769,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# pad local attention probs
# pad local attention probs
attn_scores
+=
diagonal_mask
attn_scores
+=
diagonal_mask
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_scores
),
shape_list
(
attn_scores
),
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
one_sided_attn_window_size
*
2
+
1
],
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
one_sided_attn_window_size
*
2
+
1
],
message
=
(
message
=
(
f
"attn_probs should be of size (
{
batch_size
}
,
{
seq_len
}
,
{
self
.
num_heads
}
,"
f
"attn_probs should be of size (
{
batch_size
}
,
{
seq_len
}
,
{
self
.
num_heads
}
,"
f
"
{
self
.
one_sided_attn_window_size
*
2
+
1
}
), but is of size
{
shape_list
(
attn_scores
)
}
"
f
"
{
self
.
one_sided_attn_window_size
*
2
+
1
}
), but is of size
{
shape_list
(
attn_scores
)
}
"
),
),
)
)
# compute global attn indices required through out forward fn
# compute global attn indices required through out forward fn
(
(
...
@@ -826,15 +824,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -826,15 +824,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
layer_head_mask
),
shape_list
(
layer_head_mask
),
[
self
.
num_heads
],
[
self
.
num_heads
],
message
=
(
message
=
(
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"
{
shape_list
(
layer_head_mask
)
}
"
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
),
)
)
attn_probs
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
1
,
-
1
,
1
))
*
attn_probs
attn_probs
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
1
,
-
1
,
1
))
*
attn_probs
...
@@ -857,12 +854,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -857,12 +854,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
),
),
)
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_output
),
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
],
message
=
"Unexpected size"
shape_list
(
attn_output
),
)
[
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
],
message
=
"Unexpected size"
,
)
attn_output
=
tf
.
reshape
(
attn_output
,
(
batch_size
,
seq_len
,
embed_dim
))
attn_output
=
tf
.
reshape
(
attn_output
,
(
batch_size
,
seq_len
,
embed_dim
))
...
@@ -917,20 +911,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -917,20 +911,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
"""
"""
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
query
)
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
query
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
seq_len
%
(
window_overlap
*
2
),
seq_len
%
(
window_overlap
*
2
),
0
,
0
,
message
=
f
"Sequence length should be multiple of
{
window_overlap
*
2
}
. Given
{
seq_len
}
"
,
message
=
f
"Sequence length should be multiple of
{
window_overlap
*
2
}
. Given
{
seq_len
}
"
,
)
)
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
query
),
shape_list
(
query
),
shape_list
(
key
),
shape_list
(
key
),
message
=
(
message
=
(
f
"Shape of query and key should be equal, but got query:
{
shape_list
(
query
)
}
and key:"
f
"Shape of query and key should be equal, but got query:
{
shape_list
(
query
)
}
and key:"
f
"
{
shape_list
(
key
)
}
"
f
"
{
shape_list
(
key
)
}
"
),
),
)
)
chunks_count
=
seq_len
//
window_overlap
-
1
chunks_count
=
seq_len
//
window_overlap
-
1
...
@@ -1064,22 +1057,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -1064,22 +1057,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
value
)
batch_size
,
seq_len
,
num_heads
,
head_dim
=
shape_list
(
value
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
seq_len
%
(
window_overlap
*
2
),
0
,
message
=
"Seq_len has to be multiple of 2 * window_overlap"
seq_len
%
(
window_overlap
*
2
),
)
0
,
tf
.
debugging
.
assert_equal
(
message
=
"Seq_len has to be multiple of 2 * window_overlap"
,
shape_list
(
attn_probs
)[:
3
],
)
shape_list
(
value
)[:
3
],
tf
.
debugging
.
assert_equal
(
message
=
"value and attn_probs must have same dims (except head_dim)"
,
shape_list
(
attn_probs
)[:
3
],
)
shape_list
(
value
)[:
3
],
tf
.
debugging
.
assert_equal
(
message
=
"value and attn_probs must have same dims (except head_dim)"
,
shape_list
(
attn_probs
)[
3
],
)
2
*
window_overlap
+
1
,
tf
.
debugging
.
assert_equal
(
message
=
"attn_probs last dim has to be 2 * window_overlap + 1"
,
shape_list
(
attn_probs
)[
3
],
)
2
*
window_overlap
+
1
,
message
=
"attn_probs last dim has to be 2 * window_overlap + 1"
,
)
chunks_count
=
seq_len
//
window_overlap
-
1
chunks_count
=
seq_len
//
window_overlap
-
1
...
@@ -1117,12 +1107,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -1117,12 +1107,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
(
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
),
(
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
),
)
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
chunked_value
),
shape_list
(
chunked_value
),
[
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
],
[
batch_size
*
num_heads
,
chunks_count
+
1
,
3
*
window_overlap
,
head_dim
],
message
=
"Chunked value has the wrong shape"
,
message
=
"Chunked value has the wrong shape"
,
)
)
chunked_attn_probs
=
self
.
_pad_and_diagonalize
(
chunked_attn_probs
)
chunked_attn_probs
=
self
.
_pad_and_diagonalize
(
chunked_attn_probs
)
context
=
tf
.
einsum
(
"bcwd,bcdh->bcwh"
,
chunked_attn_probs
,
chunked_value
)
context
=
tf
.
einsum
(
"bcwd,bcdh->bcwh"
,
chunked_attn_probs
,
chunked_value
)
...
@@ -1210,15 +1199,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -1210,15 +1199,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# chunk with overlap
# chunk with overlap
chunked_hidden_states
=
tf
.
signal
.
frame
(
hidden_states
,
frame_size
,
frame_hop_size
)
chunked_hidden_states
=
tf
.
signal
.
frame
(
hidden_states
,
frame_size
,
frame_hop_size
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
chunked_hidden_states
),
shape_list
(
chunked_hidden_states
),
[
batch_size
,
num_output_chunks
,
frame_size
],
[
batch_size
,
num_output_chunks
,
frame_size
],
message
=
(
message
=
(
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
f
"
{
[
batch_size
,
frame_size
,
num_output_chunks
]
}
, but got
{
shape_list
(
chunked_hidden_states
)
}
."
f
"
{
[
batch_size
,
frame_size
,
num_output_chunks
]
}
, but got
{
shape_list
(
chunked_hidden_states
)
}
."
),
),
)
)
chunked_hidden_states
=
tf
.
reshape
(
chunked_hidden_states
=
tf
.
reshape
(
chunked_hidden_states
,
chunked_hidden_states
,
...
@@ -1391,16 +1379,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -1391,16 +1379,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# compute attn scores
# compute attn scores
global_attn_scores
=
tf
.
matmul
(
global_query_vectors_only_global
,
global_key_vectors
,
transpose_b
=
True
)
global_attn_scores
=
tf
.
matmul
(
global_query_vectors_only_global
,
global_key_vectors
,
transpose_b
=
True
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
global_attn_scores
),
shape_list
(
global_attn_scores
),
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
],
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
],
message
=
(
message
=
(
"global_attn_scores have the wrong size. Size should be"
"global_attn_scores have the wrong size. Size should be"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
}
, but is"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
}
, but is"
f
"
{
shape_list
(
global_attn_scores
)
}
."
f
"
{
shape_list
(
global_attn_scores
)
}
."
),
),
)
)
global_attn_scores
=
tf
.
reshape
(
global_attn_scores
=
tf
.
reshape
(
global_attn_scores
,
global_attn_scores
,
...
@@ -1434,15 +1421,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -1434,15 +1421,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# apply layer head masking
# apply layer head masking
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
layer_head_mask
),
shape_list
(
layer_head_mask
),
[
self
.
num_heads
],
[
self
.
num_heads
],
message
=
(
message
=
(
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
f
"
{
shape_list
(
layer_head_mask
)
}
"
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
),
)
)
global_attn_probs_float
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
global_attn_probs_float
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
global_attn_probs_float
,
(
batch_size
,
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
global_attn_probs_float
,
(
batch_size
,
self
.
num_heads
,
max_num_global_attn_indices
,
seq_len
)
)
)
...
@@ -1456,16 +1442,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
...
@@ -1456,16 +1442,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# global attn output
# global attn output
global_attn_output
=
tf
.
matmul
(
global_attn_probs
,
global_value_vectors
)
global_attn_output
=
tf
.
matmul
(
global_attn_probs
,
global_value_vectors
)
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
global_attn_output
),
shape_list
(
global_attn_output
),
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
],
[
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
],
message
=
(
message
=
(
"global_attn_output tensor has the wrong size. Size should be"
"global_attn_output tensor has the wrong size. Size should be"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
)
}
, but is"
f
"
{
(
batch_size
*
self
.
num_heads
,
max_num_global_attn_indices
,
self
.
head_dim
)
}
, but is"
f
"
{
shape_list
(
global_attn_output
)
}
."
f
"
{
shape_list
(
global_attn_output
)
}
."
),
),
)
)
global_attn_output
=
tf
.
reshape
(
global_attn_output
=
tf
.
reshape
(
global_attn_output
,
global_attn_output
,
...
...
src/transformers/models/marian/modeling_tf_marian.py
View file @
31be02f1
...
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
)
)
if
tf
.
executing_eagerly
():
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
@@ -264,31 +263,25 @@ class TFMarianAttention(tf.keras.layers.Layer):
...
@@ -264,31 +263,25 @@ class TFMarianAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -296,17 +289,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
...
@@ -296,17 +289,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -316,17 +306,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
...
@@ -316,17 +306,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -375,14 +362,11 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
...
@@ -375,14 +362,11 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
)
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
hidden_states
),
if
tf
.
executing_eagerly
():
shape_list
(
residual
),
tf
.
debugging
.
assert_equal
(
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
shape_list
(
hidden_states
),
)
shape_list
(
residual
),
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -801,9 +785,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
...
@@ -801,9 +785,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
all_attentions
=
()
if
output_attentions
else
None
all_attentions
=
()
if
output_attentions
else
None
# check if head_mask has a correct number of layers specified if desired
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
if
head_mask
is
not
None
:
# have to be disabled in other modes than eager.
if
head_mask
is
not
None
and
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
head_mask
)[
0
],
shape_list
(
head_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
@@ -1009,10 +991,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
...
@@ -1009,10 +991,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
present_key_values
=
()
if
use_cache
else
None
present_key_values
=
()
if
use_cache
else
None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for
attn_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
for
attn_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
if
attn_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
attn_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)[
0
],
shape_list
(
attn_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
src/transformers/models/mbart/modeling_tf_mbart.py
View file @
31be02f1
...
@@ -232,31 +232,25 @@ class TFMBartAttention(tf.keras.layers.Layer):
...
@@ -232,31 +232,25 @@ class TFMBartAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -264,17 +258,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
...
@@ -264,17 +258,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -284,17 +275,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
...
@@ -284,17 +275,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -343,14 +331,11 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
...
@@ -343,14 +331,11 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
)
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
hidden_states
),
if
tf
.
executing_eagerly
():
shape_list
(
residual
),
tf
.
debugging
.
assert_equal
(
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
shape_list
(
hidden_states
),
)
shape_list
(
residual
),
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -786,9 +771,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
...
@@ -786,9 +771,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
all_attentions
=
()
if
output_attentions
else
None
all_attentions
=
()
if
output_attentions
else
None
# check if head_mask has a correct number of layers specified if desired
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
if
head_mask
is
not
None
:
# have to be disabled in other modes than eager.
if
head_mask
is
not
None
and
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
head_mask
)[
0
],
shape_list
(
head_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
@@ -1001,10 +984,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
...
@@ -1001,10 +984,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
present_key_values
=
()
if
use_cache
else
None
present_key_values
=
()
if
use_cache
else
None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
if
attn_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
attn_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)[
0
],
shape_list
(
attn_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
src/transformers/models/opt/modeling_tf_opt.py
View file @
31be02f1
...
@@ -206,31 +206,25 @@ class TFOPTAttention(tf.keras.layers.Layer):
...
@@ -206,31 +206,25 @@ class TFOPTAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -238,17 +232,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
...
@@ -238,17 +232,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -258,17 +249,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
...
@@ -258,17 +249,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -664,10 +652,8 @@ class TFOPTDecoder(tf.keras.layers.Layer):
...
@@ -664,10 +652,8 @@ class TFOPTDecoder(tf.keras.layers.Layer):
present_key_values
=
()
if
use_cache
else
None
present_key_values
=
()
if
use_cache
else
None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
)]:
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
)]:
if
attn_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
attn_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)[
0
],
shape_list
(
attn_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
src/transformers/models/pegasus/modeling_tf_pegasus.py
View file @
31be02f1
...
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
)
)
if
tf
.
executing_eagerly
():
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
@@ -265,31 +264,25 @@ class TFPegasusAttention(tf.keras.layers.Layer):
...
@@ -265,31 +264,25 @@ class TFPegasusAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -297,17 +290,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
...
@@ -297,17 +290,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -317,17 +307,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
...
@@ -317,17 +307,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -377,14 +364,11 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
...
@@ -377,14 +364,11 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
)
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
hidden_states
),
if
tf
.
executing_eagerly
():
shape_list
(
residual
),
tf
.
debugging
.
assert_equal
(
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
shape_list
(
hidden_states
),
)
shape_list
(
residual
),
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -804,9 +788,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
...
@@ -804,9 +788,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
all_attentions
=
()
if
output_attentions
else
None
all_attentions
=
()
if
output_attentions
else
None
# check if head_mask has a correct number of layers specified if desired
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
if
head_mask
is
not
None
:
# have to be disabled in other modes than eager.
if
head_mask
is
not
None
and
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
head_mask
)[
0
],
shape_list
(
head_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
@@ -1015,10 +997,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
...
@@ -1015,10 +997,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
present_key_values
=
()
if
use_cache
else
None
present_key_values
=
()
if
use_cache
else
None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
if
attn_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
attn_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)[
0
],
shape_list
(
attn_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
View file @
31be02f1
...
@@ -74,13 +74,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -74,13 +74,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
)
)
if
tf
.
executing_eagerly
():
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
@@ -324,31 +323,25 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
...
@@ -324,31 +323,25 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -356,17 +349,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
...
@@ -356,17 +349,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -376,17 +366,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
...
@@ -376,17 +366,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -434,14 +421,11 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
...
@@ -434,14 +421,11 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
training
=
training
,
training
=
training
,
)
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
hidden_states
),
if
tf
.
executing_eagerly
():
shape_list
(
residual
),
tf
.
debugging
.
assert_equal
(
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
shape_list
(
hidden_states
),
)
shape_list
(
residual
),
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -866,8 +850,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
...
@@ -866,8 +850,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
all_attentions
=
()
if
output_attentions
else
None
all_attentions
=
()
if
output_attentions
else
None
# check if head_mask has a correct number of layers specified if desired
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
if
head_mask
is
not
None
:
if
head_mask
is
not
None
and
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
head_mask
)[
0
],
shape_list
(
head_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
@@ -1068,9 +1051,8 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
...
@@ -1068,9 +1051,8 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
next_decoder_cache
=
()
if
use_cache
else
None
next_decoder_cache
=
()
if
use_cache
else
None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
if
attn_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
attn_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)[
0
],
shape_list
(
attn_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
View file @
31be02f1
...
@@ -161,13 +161,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -161,13 +161,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
)
)
if
tf
.
executing_eagerly
():
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
,
dtype
=
input_ids
.
dtype
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
...
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
View file @
31be02f1
...
@@ -852,31 +852,25 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
...
@@ -852,31 +852,25 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -884,17 +878,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
...
@@ -884,17 +878,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -904,17 +895,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
...
@@ -904,17 +895,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
...
src/transformers/models/xglm/modeling_tf_xglm.py
View file @
31be02f1
...
@@ -239,31 +239,25 @@ class TFXGLMAttention(tf.keras.layers.Layer):
...
@@ -239,31 +239,25 @@ class TFXGLMAttention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attn_weights
)
}
"
),
)
if
attention_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
att
n_weights
),
shape_list
(
att
ention_mask
),
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
message
=
(
f
"Attention
weights
should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention
mask
should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
att
n_weights
)
}
"
f
"
{
shape_list
(
att
ention_mask
)
}
"
),
),
)
)
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
attention_mask
),
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
shape_list
(
attention_mask
)
}
"
),
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
attn_weights
.
dtype
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -271,17 +265,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
...
@@ -271,17 +265,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
layer_head_mask
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
[
self
.
num_heads
],
f
"
{
shape_list
(
layer_head_mask
)
}
"
message
=
(
),
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is"
)
f
"
{
shape_list
(
layer_head_mask
)
}
"
),
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -291,17 +282,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
...
@@ -291,17 +282,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_probs
=
self
.
dropout
(
attn_weights
,
training
=
training
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
(
shape_list
(
attn_output
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
f
"
{
shape_list
(
attn_output
)
}
"
message
=
(
),
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is"
)
f
"
{
shape_list
(
attn_output
)
}
"
),
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -568,10 +556,8 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
...
@@ -568,10 +556,8 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
next_decoder_cache
=
()
if
use_cache
else
None
next_decoder_cache
=
()
if
use_cache
else
None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
if
attn_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
attn_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)[
0
],
shape_list
(
attn_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
src/transformers/models/xlm/modeling_tf_xlm.py
View file @
31be02f1
...
@@ -105,9 +105,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
...
@@ -105,9 +105,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
# sanity check
# sanity check
# assert shape_list(mask) == [bs, slen]
# assert shape_list(mask) == [bs, slen]
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
shape_list
(
mask
),
[
bs
,
slen
])
tf
.
debugging
.
assert_equal
(
shape_list
(
mask
),
[
bs
,
slen
])
if
causal
:
assert
causal
is
False
or
shape_list
(
attn_mask
)
==
[
bs
,
slen
,
slen
]
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)
,
[
bs
,
slen
,
slen
]
)
return
mask
,
attn_mask
return
mask
,
attn_mask
...
@@ -384,10 +384,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
...
@@ -384,10 +384,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# check inputs
# check inputs
# assert shape_list(lengths)[0] == bs
# assert shape_list(lengths)[0] == bs
if
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
lengths
)[
0
],
bs
shape_list
(
lengths
)[
0
],
bs
),
f
"Expected batch size
{
shape_list
(
lengths
)[
0
]
}
and received batch size
{
bs
}
mismatched"
),
f
"Expected batch size
{
shape_list
(
lengths
)[
0
]
}
and received batch size
{
bs
}
mismatched"
# assert lengths.max().item() <= slen
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
# assert (src_enc is None) == (src_len is None)
...
@@ -405,15 +404,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
...
@@ -405,15 +404,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
position_ids
=
tf
.
expand_dims
(
tf
.
range
(
slen
),
axis
=
0
)
position_ids
=
tf
.
expand_dims
(
tf
.
range
(
slen
),
axis
=
0
)
position_ids
=
tf
.
tile
(
position_ids
,
(
bs
,
1
))
position_ids
=
tf
.
tile
(
position_ids
,
(
bs
,
1
))
if
tf
.
executing_eagerly
():
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
position_ids
),
[
bs
,
slen
]
shape_list
(
position_ids
),
[
bs
,
slen
]
),
f
"Position id shape
{
shape_list
(
position_ids
)
}
and input shape
{
[
bs
,
slen
]
}
mismatched"
),
f
"Position id shape
{
shape_list
(
position_ids
)
}
and input shape
{
[
bs
,
slen
]
}
mismatched"
# position_ids = position_ids.transpose(0, 1)
# position_ids = position_ids.transpose(0, 1)
# langs
# langs
if
langs
is
not
None
and
tf
.
executing_eagerly
()
:
if
langs
is
not
None
:
# assert shape_list(langs) == [bs, slen] # (slen, bs)
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
langs
),
[
bs
,
slen
]
shape_list
(
langs
),
[
bs
,
slen
]
...
...
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
View file @
31be02f1
...
@@ -1693,13 +1693,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
...
@@ -1693,13 +1693,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
shifted_input_ids
==
-
100
,
tf
.
fill
(
shape_list
(
shifted_input_ids
),
pad_token_id
),
shifted_input_ids
)
)
if
tf
.
executing_eagerly
():
# "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100"
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
))
assert_gte0
=
tf
.
debugging
.
assert_greater_equal
(
shifted_input_ids
,
tf
.
constant
(
0
))
# Make sure the assertion op is called by wrapping the result in an identity no-op
# Make sure the assertion op is called by wrapping the result in an identity no-op
with
tf
.
control_dependencies
([
assert_gte0
]):
with
tf
.
control_dependencies
([
assert_gte0
]):
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
shifted_input_ids
=
tf
.
identity
(
shifted_input_ids
)
return
shifted_input_ids
return
shifted_input_ids
...
@@ -1837,24 +1836,18 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
...
@@ -1837,24 +1836,18 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
src_len
=
shape_list
(
key_states
)[
1
]
src_len
=
shape_list
(
key_states
)[
1
]
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
attn_weights
=
tf
.
matmul
(
query_states
,
key_states
,
transpose_b
=
True
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_weights
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
tf
.
debugging
.
assert_equal
(
message
=
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is
{
shape_list
(
attn_weights
)
}
"
,
shape_list
(
attn_weights
),
)
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
],
message
=
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is
{
shape_list
(
attn_weights
)
}
"
,
)
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attention_mask
),
if
tf
.
executing_eagerly
():
[
bsz
,
1
,
tgt_len
,
src_len
],
tf
.
debugging
.
assert_equal
(
message
=
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is
{
shape_list
(
attention_mask
)
}
"
,
shape_list
(
attention_mask
),
)
[
bsz
,
1
,
tgt_len
,
src_len
],
message
=
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is
{
shape_list
(
attention_mask
)
}
"
,
)
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
+
attention_mask
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
tf
.
reshape
(
attn_weights
,
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
...
@@ -1862,14 +1855,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
...
@@ -1862,14 +1855,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
stable_softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
layer_head_mask
),
if
tf
.
executing_eagerly
():
[
self
.
num_heads
],
tf
.
debugging
.
assert_equal
(
message
=
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is
{
shape_list
(
layer_head_mask
)
}
"
,
shape_list
(
layer_head_mask
),
)
[
self
.
num_heads
],
message
=
f
"Head mask for a single layer should be of size
{
(
self
.
num_heads
)
}
, but is
{
shape_list
(
layer_head_mask
)
}
"
,
)
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
=
tf
.
reshape
(
layer_head_mask
,
(
1
,
-
1
,
1
,
1
))
*
tf
.
reshape
(
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
...
@@ -1880,14 +1870,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
...
@@ -1880,14 +1870,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
attn_output
=
tf
.
matmul
(
attn_probs
,
value_states
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
attn_output
),
if
tf
.
executing_eagerly
():
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
tf
.
debugging
.
assert_equal
(
message
=
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is
{
shape_list
(
attn_output
)
}
"
,
shape_list
(
attn_output
),
)
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
],
message
=
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is
{
shape_list
(
attn_output
)
}
"
,
)
attn_output
=
tf
.
transpose
(
attn_output
=
tf
.
transpose
(
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
tf
.
reshape
(
attn_output
,
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)),
(
0
,
2
,
1
,
3
)
...
@@ -1929,14 +1916,11 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
...
@@ -1929,14 +1916,11 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
layer_head_mask
=
layer_head_mask
)
)
# The tf.debugging asserts are not compliant with XLA then they
tf
.
debugging
.
assert_equal
(
# have to be disabled in other modes than eager.
shape_list
(
hidden_states
),
if
tf
.
executing_eagerly
():
shape_list
(
residual
),
tf
.
debugging
.
assert_equal
(
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
shape_list
(
hidden_states
),
)
shape_list
(
residual
),
message
=
f
"Self attn modified the shape of query
{
shape_list
(
residual
)
}
to
{
shape_list
(
hidden_states
)
}
"
,
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -2332,9 +2316,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
...
@@ -2332,9 +2316,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
all_attentions
=
()
if
output_attentions
else
None
all_attentions
=
()
if
output_attentions
else
None
# check if head_mask has a correct number of layers specified if desired
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
if
head_mask
is
not
None
:
# have to be disabled in other modes than eager.
if
head_mask
is
not
None
and
tf
.
executing_eagerly
():
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
head_mask
)[
0
],
shape_list
(
head_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
@@ -2529,10 +2511,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
...
@@ -2529,10 +2511,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
present_key_values
=
()
if
use_cache
else
None
present_key_values
=
()
if
use_cache
else
None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
for
attn_mask_name
,
attn_mask
in
[(
"head_mask"
,
head_mask
),
(
"cross_attn_head_mask"
,
cross_attn_head_mask
)]:
if
attn_mask
is
not
None
and
tf
.
executing_eagerly
()
:
if
attn_mask
is
not
None
:
tf
.
debugging
.
assert_equal
(
tf
.
debugging
.
assert_equal
(
shape_list
(
attn_mask
)[
0
],
shape_list
(
attn_mask
)[
0
],
len
(
self
.
layers
),
len
(
self
.
layers
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment