Unverified Commit 31be02f1 authored by Joao Gante, committed by GitHub

TF: tf.debugging assertions without tf.running_eagerly() protection (#19030)

parent 693ba2cc
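The change is the same across every file below: the `if tf.executing_eagerly():` guard (and, where present, the accompanying "not compliant with XLA" comment) is dropped from around the `tf.debugging` shape checks, so the assertions are also emitted when the code is traced into a graph. A minimal sketch of the resulting pattern; the helper name, shapes, and values here are invented for illustration and are not taken from the diff:

import tensorflow as tf

def check_attn_weights(attn_weights: tf.Tensor, bsz: int, num_heads: int, tgt_len: int, src_len: int) -> tf.Tensor:
    # Shape check emitted unconditionally, mirroring the pattern after this commit:
    # no `if tf.executing_eagerly():` wrapper around the tf.debugging call.
    tf.debugging.assert_equal(
        tf.shape(attn_weights),
        [bsz * num_heads, tgt_len, src_len],
        message="Attention weights have an unexpected shape",
    )
    return attn_weights

@tf.function  # the assertion is now traced into the graph instead of only running eagerly
def run(x: tf.Tensor) -> tf.Tensor:
    return check_attn_weights(x, bsz=2, num_heads=4, tgt_len=3, src_len=3)

run(tf.zeros((8, 3, 3)))  # passes; a (7, 3, 3) input would raise tf.errors.InvalidArgumentError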
@@ -71,13 +71,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
    )

-   if tf.executing_eagerly():
    # "Verify that `labels` has only positive values and -100"
    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

    # Make sure the assertion op is called by wrapping the result in an identity no-op
    with tf.control_dependencies([assert_gte0]):
        shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids
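The `tf.control_dependencies` block kept in the hunk above is what makes the now-unguarded assertion effective outside eager mode: the assert op only runs if a consumed tensor depends on it, so the result is routed through `tf.identity` under the control dependency. A standalone sketch of the same idiom (the function name and sample values are invented for illustration):

import tensorflow as tf

def replace_ignored_labels(labels: tf.Tensor, pad_token_id: int) -> tf.Tensor:
    # Re-creation of the shift_tokens_right idiom above with invented names:
    # swap the -100 "ignore" positions for the pad token, then assert non-negativity.
    out = tf.where(labels == -100, tf.fill(tf.shape(labels), pad_token_id), labels)
    assert_gte0 = tf.debugging.assert_greater_equal(out, tf.constant(0, dtype=out.dtype))
    # In a traced graph the assert op would be dead code unless something depends on it,
    # so the result is passed through an identity that carries the control dependency.
    with tf.control_dependencies([assert_gte0]):
        out = tf.identity(out)
    return out

print(replace_ignored_labels(tf.constant([[5, -100, 7]]), pad_token_id=0))  # [[5 0 7]]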
@@ -229,31 +228,25 @@ class TFBartAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_weights),
            [bsz * self.num_heads, tgt_len, src_len],
            message=(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {shape_list(attn_weights)}"
            ),
        )

        if attention_mask is not None:
-           # The tf.debugging asserts are not compliant with XLA then they
-           # have to be disabled in other modes than eager.
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(attention_mask),
                [bsz, 1, tgt_len, src_len],
                message=(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {shape_list(attention_mask)}"
                ),
            )

            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

@@ -261,17 +254,14 @@ class TFBartAttention(tf.keras.layers.Layer):
        attn_weights = stable_softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-           # The tf.debugging asserts are not compliant with XLA then they
-           # have to be disabled in other modes than eager.
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )

            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)

@@ -281,17 +271,14 @@ class TFBartAttention(tf.keras.layers.Layer):
        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_output),
            [bsz * self.num_heads, tgt_len, self.head_dim],
            message=(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {shape_list(attn_output)}"
            ),
        )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)

@@ -339,14 +326,11 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(hidden_states),
            shape_list(residual),
            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
        )

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states

@@ -776,9 +760,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if head_mask is not None and tf.executing_eagerly():
+       if head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(head_mask)[0],
                len(self.layers),

@@ -983,10 +965,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
        present_key_values = () if use_cache else None

        # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-           if attn_mask is not None and tf.executing_eagerly():
+           if attn_mask is not None:
                tf.debugging.assert_equal(
                    shape_list(attn_mask)[0],
                    len(self.layers),
...
@@ -73,13 +73,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
    )

-   if tf.executing_eagerly():
    # "Verify that `labels` has only positive values and -100"
    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

    # Make sure the assertion op is called by wrapping the result in an identity no-op
    with tf.control_dependencies([assert_gte0]):
        shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids

@@ -225,31 +224,25 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_weights),
            [bsz * self.num_heads, tgt_len, src_len],
            message=(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {shape_list(attn_weights)}"
            ),
        )

        if attention_mask is not None:
-           # The tf.debugging asserts are not compliant with XLA then they
-           # have to be disabled in other modes than eager.
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(attention_mask),
                [bsz, 1, tgt_len, src_len],
                message=(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {shape_list(attention_mask)}"
                ),
            )

            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

@@ -257,17 +250,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
        attn_weights = stable_softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-           # The tf.debugging asserts are not compliant with XLA then they
-           # have to be disabled in other modes than eager.
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )

            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)

@@ -277,17 +267,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_output),
            [bsz * self.num_heads, tgt_len, self.head_dim],
            message=(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {shape_list(attn_output)}"
            ),
        )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)

@@ -337,14 +324,11 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(hidden_states),
            shape_list(residual),
            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
        )

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states

@@ -755,9 +739,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if head_mask is not None and tf.executing_eagerly():
+       if head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(head_mask)[0],
                len(self.layers),

@@ -966,10 +948,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
        present_key_values = () if use_cache else None

        # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-           if attn_mask is not None and tf.executing_eagerly():
+           if attn_mask is not None:
                tf.debugging.assert_equal(
                    shape_list(attn_mask)[0],
                    len(self.layers),
...
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
    )

-   if tf.executing_eagerly():
    # "Verify that `labels` has only positive values and -100"
    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

    # Make sure the assertion op is called by wrapping the result in an identity no-op
    with tf.control_dependencies([assert_gte0]):
        shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids

@@ -225,31 +224,25 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_weights),
            [bsz * self.num_heads, tgt_len, src_len],
            message=(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {shape_list(attn_weights)}"
            ),
        )

        if attention_mask is not None:
-           # The tf.debugging asserts are not compliant with XLA then they
-           # have to be disabled in other modes than eager.
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(attention_mask),
                [bsz, 1, tgt_len, src_len],
                message=(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {shape_list(attention_mask)}"
                ),
            )

            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

@@ -257,17 +250,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
        attn_weights = stable_softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-           # The tf.debugging asserts are not compliant with XLA then they
-           # have to be disabled in other modes than eager.
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )

            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)

@@ -277,17 +267,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_output),
            [bsz * self.num_heads, tgt_len, self.head_dim],
            message=(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {shape_list(attn_output)}"
            ),
        )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)

@@ -336,14 +323,11 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(hidden_states),
            shape_list(residual),
            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
        )

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states

@@ -761,9 +745,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if head_mask is not None and tf.executing_eagerly():
+       if head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(head_mask)[0],
                len(self.layers),

@@ -968,10 +950,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
        present_key_values = () if use_cache else None

        # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
-           if attn_mask is not None and tf.executing_eagerly():
+           if attn_mask is not None:
                tf.debugging.assert_equal(
                    shape_list(attn_mask)[0],
                    len(self.layers),
...
@@ -171,13 +171,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
    )

-   if tf.executing_eagerly():
    # "Verify that `labels` has only positive values and -100"
    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

    # Make sure the assertion op is called by wrapping the result in an identity no-op
    with tf.control_dependencies([assert_gte0]):
        shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids
...
@@ -200,9 +200,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
    # sanity check
    # assert shape_list(mask) == [bs, slen]
-   if tf.executing_eagerly():
-       tf.debugging.assert_equal(shape_list(mask), [bs, slen])
-   assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
+   tf.debugging.assert_equal(shape_list(mask), [bs, slen])
+   if causal:
+       tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])

    return mask, attn_mask
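Besides dropping the eager-only guard, the get_masks hunk above replaces a bare Python `assert` on the attention-mask shape with `tf.debugging.assert_equal` under `if causal:`. A small sketch of why that matters inside `tf.function` (the wrapper below and its shapes are invented for illustration):

import tensorflow as tf

@tf.function
def check_causal_mask(attn_mask: tf.Tensor, bs: int, slen: int) -> tf.Tensor:
    # Graph-aware check: traced into the function and raises InvalidArgumentError on mismatch.
    tf.debugging.assert_equal(tf.shape(attn_mask), [bs, slen, slen])
    # A bare Python `assert` here would only run at trace time (and is stripped
    # entirely under `python -O`), so it cannot guard later calls with new values.
    return attn_mask

check_causal_mask(tf.ones((2, 4, 4)), bs=2, slen=4)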
@@ -517,10 +517,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
        # check inputs
        # assert shape_list(lengths)[0] == bs
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(lengths)[0], bs
        ), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
        # assert lengths.max().item() <= slen
        # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
        # assert (src_enc is None) == (src_len is None)

@@ -538,15 +537,14 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
            position_ids = tf.expand_dims(tf.range(slen), axis=0)
            position_ids = tf.tile(position_ids, (bs, 1))

-       if tf.executing_eagerly():
        # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
        tf.debugging.assert_equal(
            shape_list(position_ids), [bs, slen]
        ), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
        # position_ids = position_ids.transpose(0, 1)

        # langs
-       if langs is not None and tf.executing_eagerly():
+       if langs is not None:
            # assert shape_list(langs) == [bs, slen]  # (slen, bs)
            tf.debugging.assert_equal(
                shape_list(langs), [bs, slen]
...
@@ -816,31 +816,25 @@ class TFHubertAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_weights),
            [bsz * self.num_heads, tgt_len, src_len],
            message=(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {shape_list(attn_weights)}"
            ),
        )

        if attention_mask is not None:
-           # The tf.debugging asserts are not compliant with XLA then they
-           # have to be disabled in other modes than eager.
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(attention_mask),
                [bsz, 1, tgt_len, src_len],
                message=(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {shape_list(attention_mask)}"
                ),
            )

            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

@@ -848,17 +842,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
        attn_weights = stable_softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-           # The tf.debugging asserts are not compliant with XLA then they
-           # have to be disabled in other modes than eager.
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )

            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)

@@ -868,17 +859,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

-       # The tf.debugging asserts are not compliant with XLA then they
-       # have to be disabled in other modes than eager.
-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_output),
            [bsz * self.num_heads, tgt_len, self.head_dim],
            message=(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {shape_list(attn_output)}"
            ),
        )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
...
@@ -64,12 +64,11 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
    )

    # "Verify that `labels` has only positive values and -100"
-   if tf.executing_eagerly():
    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))

    # Make sure the assertion op is called by wrapping the result in an identity no-op
    with tf.control_dependencies([assert_gte0]):
        shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids

@@ -213,12 +212,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
        value_vectors = self.value(hidden_states)
        batch_size, seq_len, embed_dim = shape_list(hidden_states)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            embed_dim,
            self.embed_dim,
            message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
        )

        # normalize query
        query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))

@@ -245,15 +243,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
        # pad local attention probs
        attn_scores += diagonal_mask

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_scores),
            [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
            message=(
                f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
                f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
            ),
        )

        # compute global attn indices required through out forward fn
        (

@@ -301,15 +298,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
        )

        if layer_head_mask is not None:
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )

            attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

@@ -332,12 +328,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
            ),
        )

-       if tf.executing_eagerly():
-           tf.debugging.assert_equal(
-               shape_list(attn_output),
-               [batch_size, seq_len, self.num_heads, self.head_dim],
-               message="Unexpected size",
-           )
+       tf.debugging.assert_equal(
+           shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
+       )

        attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
@@ -392,20 +385,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
        """
        batch_size, seq_len, num_heads, head_dim = shape_list(query)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            seq_len % (window_overlap * 2),
            0,
            message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
        )
        tf.debugging.assert_equal(
            shape_list(query),
            shape_list(key),
            message=(
                f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
                f" {shape_list(key)}"
            ),
        )

        chunks_count = seq_len // window_overlap - 1

@@ -539,22 +531,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
        batch_size, seq_len, num_heads, head_dim = shape_list(value)

-       if tf.executing_eagerly():
-           tf.debugging.assert_equal(
-               seq_len % (window_overlap * 2),
-               0,
-               message="Seq_len has to be multiple of 2 * window_overlap",
-           )
+       tf.debugging.assert_equal(
+           seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
+       )
        tf.debugging.assert_equal(
            shape_list(attn_probs)[:3],
            shape_list(value)[:3],
            message="value and attn_probs must have same dims (except head_dim)",
        )
        tf.debugging.assert_equal(
            shape_list(attn_probs)[3],
            2 * window_overlap + 1,
            message="attn_probs last dim has to be 2 * window_overlap + 1",
        )

        chunks_count = seq_len // window_overlap - 1

@@ -592,12 +581,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
            (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
        )

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(chunked_value),
            [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
            message="Chunked value has the wrong shape",
        )

        chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
        context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)

@@ -685,15 +673,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
        # chunk with overlap
        chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(chunked_hidden_states),
            [batch_size, num_output_chunks, frame_size],
            message=(
                "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
                f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
            ),
        )

        chunked_hidden_states = tf.reshape(
            chunked_hidden_states,

@@ -866,16 +853,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
        # compute attn scores
        global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(global_attn_scores),
            [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
            message=(
                "global_attn_scores have the wrong size. Size should be"
                f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
                f" {shape_list(global_attn_scores)}."
            ),
        )

        global_attn_scores = tf.reshape(
            global_attn_scores,

@@ -909,15 +895,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
        # apply layer head masking
        if layer_head_mask is not None:
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )

            global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
            )

@@ -931,16 +916,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
        # global attn output
        global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(global_attn_output),
            [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
            message=(
                "global_attn_output tensor has the wrong size. Size should be"
                f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
                f" {shape_list(global_attn_output)}."
            ),
        )

        global_attn_output = tf.reshape(
            global_attn_output,
@@ -1091,27 +1075,25 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_weights),
            [bsz * self.num_heads, tgt_len, src_len],
            message=(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {shape_list(attn_weights)}"
            ),
        )

        if attention_mask is not None:
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(attention_mask),
                [bsz, 1, tgt_len, src_len],
                message=(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {shape_list(attention_mask)}"
                ),
            )

            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
                attention_mask, dtype=attn_weights.dtype
            )

@@ -1120,15 +1102,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
        attn_weights = stable_softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )

            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)

@@ -1139,15 +1120,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
        attn_output = tf.matmul(attn_probs, value_states)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_output),
            [bsz * self.num_heads, tgt_len, self.head_dim],
            message=(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {shape_list(attn_output)}"
            ),
        )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)

@@ -1199,12 +1179,11 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
        hidden_states = layer_outputs[0]

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(hidden_states),
            shape_list(residual),
            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
        )

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states

@@ -1792,7 +1771,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
        all_attentions = all_global_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
-       if head_mask is not None and tf.executing_eagerly():
+       if head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(head_mask)[0],
                len(self.layers),

@@ -2055,7 +2034,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
        present_key_values = ()

        # check if head_mask has a correct number of layers specified if desired
-       if head_mask is not None and tf.executing_eagerly():
+       if head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(head_mask)[0],
                len(self.layers),
...
@@ -738,12 +738,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
        value_vectors = self.value(hidden_states)
        batch_size, seq_len, embed_dim = shape_list(hidden_states)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            embed_dim,
            self.embed_dim,
            message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
        )

        # normalize query
        query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))

@@ -770,15 +769,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
        # pad local attention probs
        attn_scores += diagonal_mask

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(attn_scores),
            [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
            message=(
                f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
                f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
            ),
        )

        # compute global attn indices required through out forward fn
        (

@@ -826,15 +824,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
        )

        if layer_head_mask is not None:
-           if tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )

            attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

@@ -857,12 +854,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
            ),
        )

-       if tf.executing_eagerly():
-           tf.debugging.assert_equal(
-               shape_list(attn_output),
-               [batch_size, seq_len, self.num_heads, self.head_dim],
-               message="Unexpected size",
-           )
+       tf.debugging.assert_equal(
+           shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
+       )

        attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))

@@ -917,20 +911,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
        """
        batch_size, seq_len, num_heads, head_dim = shape_list(query)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            seq_len % (window_overlap * 2),
            0,
            message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
        )
        tf.debugging.assert_equal(
            shape_list(query),
            shape_list(key),
            message=(
                f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
                f" {shape_list(key)}"
            ),
        )

        chunks_count = seq_len // window_overlap - 1

@@ -1064,22 +1057,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
        batch_size, seq_len, num_heads, head_dim = shape_list(value)

-       if tf.executing_eagerly():
-           tf.debugging.assert_equal(
-               seq_len % (window_overlap * 2),
-               0,
-               message="Seq_len has to be multiple of 2 * window_overlap",
-           )
+       tf.debugging.assert_equal(
+           seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
+       )
        tf.debugging.assert_equal(
            shape_list(attn_probs)[:3],
            shape_list(value)[:3],
            message="value and attn_probs must have same dims (except head_dim)",
        )
        tf.debugging.assert_equal(
            shape_list(attn_probs)[3],
            2 * window_overlap + 1,
            message="attn_probs last dim has to be 2 * window_overlap + 1",
        )

        chunks_count = seq_len // window_overlap - 1

@@ -1117,12 +1107,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
            (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
        )

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(chunked_value),
            [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
            message="Chunked value has the wrong shape",
        )

        chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
        context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)

@@ -1210,15 +1199,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
        # chunk with overlap
        chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(chunked_hidden_states),
            [batch_size, num_output_chunks, frame_size],
            message=(
                "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
                f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
            ),
        )

        chunked_hidden_states = tf.reshape(
            chunked_hidden_states,

@@ -1391,16 +1379,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
        # compute attn scores
        global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)

-       if tf.executing_eagerly():
        tf.debugging.assert_equal(
            shape_list(global_attn_scores),
            [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
            message=(
                "global_attn_scores have the wrong size. Size should be"
                f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
                f" {shape_list(global_attn_scores)}."
            ),
        )

        global_attn_scores = tf.reshape(
            global_attn_scores,

@@ -1434,15 +1421,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
        # apply layer head masking
        if layer_head_mask is not None:
if tf.executing_eagerly(): tf.debugging.assert_equal(
tf.debugging.assert_equal( shape_list(layer_head_mask),
shape_list(layer_head_mask), [self.num_heads],
[self.num_heads], message=(
message=( f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f"Head mask for a single layer should be of size {(self.num_heads)}, but is" f" {shape_list(layer_head_mask)}"
f" {shape_list(layer_head_mask)}" ),
), )
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
) )
...@@ -1456,16 +1442,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -1456,16 +1442,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# global attn output # global attn output
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors) global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
if tf.executing_eagerly(): tf.debugging.assert_equal(
tf.debugging.assert_equal( shape_list(global_attn_output),
shape_list(global_attn_output), [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim], message=(
message=( "global_attn_output tensor has the wrong size. Size should be"
"global_attn_output tensor has the wrong size. Size should be" f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is" f" {shape_list(global_attn_output)}."
f" {shape_list(global_attn_output)}." ),
), )
)
global_attn_output = tf.reshape( global_attn_output = tf.reshape(
global_attn_output, global_attn_output,
......
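The Longformer hunks above all apply the same change: the `if tf.executing_eagerly():` guard is dropped and the `tf.debugging` check runs unconditionally. Below is a minimal standalone sketch of that pattern (not part of the diff), assuming a simplified `shape_list` helper mirroring the transformers one: because the helper returns plain Python ints for statically known dimensions, `tf.debugging.assert_equal` can usually resolve the comparison at trace time, so the check also holds up when the layer is traced with `tf.function`.

```python
# Minimal sketch, not from the diff. `shape_list` here is a simplified stand-in
# for the transformers helper: Python ints for static dims, tensors otherwise.
import tensorflow as tf


def shape_list(tensor):
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]


@tf.function  # the check also runs when traced, not only in eager mode
def check_attn_weights(attn_weights, bsz, num_heads, tgt_len, src_len):
    tf.debugging.assert_equal(
        shape_list(attn_weights),
        [bsz * num_heads, tgt_len, src_len],
        message="Attention weights have an unexpected shape",
    )
    return attn_weights


# With fully static shapes the comparison is settled while tracing.
check_attn_weights(tf.zeros((2 * 4, 5, 5)), bsz=2, num_heads=4, tgt_len=5, src_len=5)
```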
...@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to ...@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
) )
if tf.executing_eagerly(): # "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100" assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op # Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]): with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids) shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids return shifted_input_ids
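The `shift_tokens_right` hunk above keeps the `tf.control_dependencies` / `tf.identity` wiring even though the eager guard is gone. A small self-contained sketch (not taken from the diff) of the idiom, echoing the code's own comment: the result is routed through an identity op that depends on the assertion, so the check is guaranteed to execute before the shifted ids are consumed.

```python
# Minimal sketch, not from the diff: wire a value assertion into the graph via
# tf.control_dependencies + tf.identity so it cannot be skipped.
import tensorflow as tf


@tf.function
def require_non_negative(ids):
    assert_gte0 = tf.debugging.assert_greater_equal(ids, tf.constant(0, dtype=ids.dtype))
    with tf.control_dependencies([assert_gte0]):
        # the identity node depends on the assert, so the check runs whenever
        # the returned tensor is used downstream
        return tf.identity(ids)


require_non_negative(tf.constant([[1, 2, 0]]))        # passes
# require_non_negative(tf.constant([[1, -5, 0]]))     # raises InvalidArgumentError
```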
...@@ -264,31 +263,25 @@ class TFMarianAttention(tf.keras.layers.Layer): ...@@ -264,31 +263,25 @@ class TFMarianAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1] src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True) attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_weights),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_weights), shape_list(attention_mask),
[bsz * self.num_heads, tgt_len, src_len], [bsz, 1, tgt_len, src_len],
message=( message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}" f" {shape_list(attention_mask)}"
), ),
) )
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
...@@ -296,17 +289,14 @@ class TFMarianAttention(tf.keras.layers.Layer): ...@@ -296,17 +289,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1) attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(layer_head_mask),
if tf.executing_eagerly(): [self.num_heads],
tf.debugging.assert_equal( message=(
shape_list(layer_head_mask), f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
[self.num_heads], f" {shape_list(layer_head_mask)}"
message=( ),
f"Head mask for a single layer should be of size {(self.num_heads)}, but is" )
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len) attn_weights, (bsz, self.num_heads, tgt_len, src_len)
...@@ -316,17 +306,14 @@ class TFMarianAttention(tf.keras.layers.Layer): ...@@ -316,17 +306,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training) attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_output),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, self.head_dim],
tf.debugging.assert_equal( message=(
shape_list(attn_output), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
[bsz * self.num_heads, tgt_len, self.head_dim], f" {shape_list(attn_output)}"
message=( ),
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" )
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose( attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
...@@ -375,14 +362,11 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer): ...@@ -375,14 +362,11 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(hidden_states),
if tf.executing_eagerly(): shape_list(residual),
tf.debugging.assert_equal( message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
shape_list(hidden_states), )
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -801,9 +785,7 @@ class TFMarianEncoder(tf.keras.layers.Layer): ...@@ -801,9 +785,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they if head_mask is not None:
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(head_mask)[0], shape_list(head_mask)[0],
len(self.layers), len(self.layers),
...@@ -1009,10 +991,8 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -1009,10 +991,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly(): if attn_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_mask)[0], shape_list(attn_mask)[0],
len(self.layers), len(self.layers),
......
...@@ -232,31 +232,25 @@ class TFMBartAttention(tf.keras.layers.Layer): ...@@ -232,31 +232,25 @@ class TFMBartAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1] src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True) attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_weights),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_weights), shape_list(attention_mask),
[bsz * self.num_heads, tgt_len, src_len], [bsz, 1, tgt_len, src_len],
message=( message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}" f" {shape_list(attention_mask)}"
), ),
) )
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
...@@ -264,17 +258,14 @@ class TFMBartAttention(tf.keras.layers.Layer): ...@@ -264,17 +258,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1) attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(layer_head_mask),
if tf.executing_eagerly(): [self.num_heads],
tf.debugging.assert_equal( message=(
shape_list(layer_head_mask), f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
[self.num_heads], f" {shape_list(layer_head_mask)}"
message=( ),
f"Head mask for a single layer should be of size {(self.num_heads)}, but is" )
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len) attn_weights, (bsz, self.num_heads, tgt_len, src_len)
...@@ -284,17 +275,14 @@ class TFMBartAttention(tf.keras.layers.Layer): ...@@ -284,17 +275,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training) attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_output),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, self.head_dim],
tf.debugging.assert_equal( message=(
shape_list(attn_output), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
[bsz * self.num_heads, tgt_len, self.head_dim], f" {shape_list(attn_output)}"
message=( ),
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" )
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose( attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
...@@ -343,14 +331,11 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer): ...@@ -343,14 +331,11 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(hidden_states),
if tf.executing_eagerly(): shape_list(residual),
tf.debugging.assert_equal( message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
shape_list(hidden_states), )
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -786,9 +771,7 @@ class TFMBartEncoder(tf.keras.layers.Layer): ...@@ -786,9 +771,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they if head_mask is not None:
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(head_mask)[0], shape_list(head_mask)[0],
len(self.layers), len(self.layers),
...@@ -1001,10 +984,8 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -1001,10 +984,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly(): if attn_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_mask)[0], shape_list(attn_mask)[0],
len(self.layers), len(self.layers),
......
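The decoder hunk above now validates `head_mask` / `cross_attn_head_mask` against the layer count in every execution mode, not only eagerly. A toy sketch of that check in isolation; `num_layers` and `num_heads` are made-up values for illustration only.

```python
# Minimal sketch, illustrative values only: the leading dimension of each
# per-layer mask must match the number of decoder layers.
import tensorflow as tf

num_layers = 6   # assumption for the sketch
num_heads = 8    # assumption for the sketch
head_mask = tf.ones((num_layers, num_heads))

for mask_name, mask in [("head_mask", head_mask), ("cross_attn_head_mask", None)]:
    if mask is not None:
        tf.debugging.assert_equal(
            tf.shape(mask)[0],
            num_layers,
            message=f"The {mask_name} should be specified for {num_layers} layers.",
        )
```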
...@@ -206,31 +206,25 @@ class TFOPTAttention(tf.keras.layers.Layer): ...@@ -206,31 +206,25 @@ class TFOPTAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1] src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True) attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_weights),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_weights), shape_list(attention_mask),
[bsz * self.num_heads, tgt_len, src_len], [bsz, 1, tgt_len, src_len],
message=( message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}" f" {shape_list(attention_mask)}"
), ),
) )
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
...@@ -238,17 +232,14 @@ class TFOPTAttention(tf.keras.layers.Layer): ...@@ -238,17 +232,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1) attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(layer_head_mask),
if tf.executing_eagerly(): [self.num_heads],
tf.debugging.assert_equal( message=(
shape_list(layer_head_mask), f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
[self.num_heads], f" {shape_list(layer_head_mask)}"
message=( ),
f"Head mask for a single layer should be of size {(self.num_heads)}, but is" )
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len) attn_weights, (bsz, self.num_heads, tgt_len, src_len)
...@@ -258,17 +249,14 @@ class TFOPTAttention(tf.keras.layers.Layer): ...@@ -258,17 +249,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training) attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_output),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, self.head_dim],
tf.debugging.assert_equal( message=(
shape_list(attn_output), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
[bsz * self.num_heads, tgt_len, self.head_dim], f" {shape_list(attn_output)}"
message=( ),
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" )
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose( attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
...@@ -664,10 +652,8 @@ class TFOPTDecoder(tf.keras.layers.Layer): ...@@ -664,10 +652,8 @@ class TFOPTDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask)]: for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
if attn_mask is not None and tf.executing_eagerly(): if attn_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_mask)[0], shape_list(attn_mask)[0],
len(self.layers), len(self.layers),
......
...@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to ...@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
) )
if tf.executing_eagerly(): # "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100" assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op # Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]): with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids) shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids return shifted_input_ids
...@@ -265,31 +264,25 @@ class TFPegasusAttention(tf.keras.layers.Layer): ...@@ -265,31 +264,25 @@ class TFPegasusAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1] src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True) attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_weights),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_weights), shape_list(attention_mask),
[bsz * self.num_heads, tgt_len, src_len], [bsz, 1, tgt_len, src_len],
message=( message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}" f" {shape_list(attention_mask)}"
), ),
) )
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
...@@ -297,17 +290,14 @@ class TFPegasusAttention(tf.keras.layers.Layer): ...@@ -297,17 +290,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1) attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(layer_head_mask),
if tf.executing_eagerly(): [self.num_heads],
tf.debugging.assert_equal( message=(
shape_list(layer_head_mask), f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
[self.num_heads], f" {shape_list(layer_head_mask)}"
message=( ),
f"Head mask for a single layer should be of size {(self.num_heads)}, but is" )
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len) attn_weights, (bsz, self.num_heads, tgt_len, src_len)
...@@ -317,17 +307,14 @@ class TFPegasusAttention(tf.keras.layers.Layer): ...@@ -317,17 +307,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training) attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_output),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, self.head_dim],
tf.debugging.assert_equal( message=(
shape_list(attn_output), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
[bsz * self.num_heads, tgt_len, self.head_dim], f" {shape_list(attn_output)}"
message=( ),
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" )
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose( attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
...@@ -377,14 +364,11 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer): ...@@ -377,14 +364,11 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(hidden_states),
if tf.executing_eagerly(): shape_list(residual),
tf.debugging.assert_equal( message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
shape_list(hidden_states), )
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -804,9 +788,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer): ...@@ -804,9 +788,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they if head_mask is not None:
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(head_mask)[0], shape_list(head_mask)[0],
len(self.layers), len(self.layers),
...@@ -1015,10 +997,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -1015,10 +997,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly(): if attn_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_mask)[0], shape_list(attn_mask)[0],
len(self.layers), len(self.layers),
......
...@@ -74,13 +74,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to ...@@ -74,13 +74,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
) )
if tf.executing_eagerly(): # "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100" assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op # Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]): with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids) shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids return shifted_input_ids
...@@ -324,31 +323,25 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer): ...@@ -324,31 +323,25 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1] src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True) attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_weights),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_weights), shape_list(attention_mask),
[bsz * self.num_heads, tgt_len, src_len], [bsz, 1, tgt_len, src_len],
message=( message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}" f" {shape_list(attention_mask)}"
), ),
) )
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
...@@ -356,17 +349,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer): ...@@ -356,17 +349,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1) attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(layer_head_mask),
if tf.executing_eagerly(): [self.num_heads],
tf.debugging.assert_equal( message=(
shape_list(layer_head_mask), f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
[self.num_heads], f" {shape_list(layer_head_mask)}"
message=( ),
f"Head mask for a single layer should be of size {(self.num_heads)}, but is" )
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len) attn_weights, (bsz, self.num_heads, tgt_len, src_len)
...@@ -376,17 +366,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer): ...@@ -376,17 +366,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training) attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_output),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, self.head_dim],
tf.debugging.assert_equal( message=(
shape_list(attn_output), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
[bsz * self.num_heads, tgt_len, self.head_dim], f" {shape_list(attn_output)}"
message=( ),
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" )
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose( attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
...@@ -434,14 +421,11 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer): ...@@ -434,14 +421,11 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
training=training, training=training,
) )
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(hidden_states),
if tf.executing_eagerly(): shape_list(residual),
tf.debugging.assert_equal( message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
shape_list(hidden_states), )
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -866,8 +850,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer): ...@@ -866,8 +850,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager. if head_mask is not None:
if head_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(head_mask)[0], shape_list(head_mask)[0],
len(self.layers), len(self.layers),
...@@ -1068,9 +1051,8 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer): ...@@ -1068,9 +1051,8 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly(): if attn_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_mask)[0], shape_list(attn_mask)[0],
len(self.layers), len(self.layers),
......
...@@ -161,13 +161,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to ...@@ -161,13 +161,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
) )
if tf.executing_eagerly(): # "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100" assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op # Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]): with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids) shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids return shifted_input_ids
......
...@@ -852,31 +852,25 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer): ...@@ -852,31 +852,25 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1] src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True) attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_weights),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_weights), shape_list(attention_mask),
[bsz * self.num_heads, tgt_len, src_len], [bsz, 1, tgt_len, src_len],
message=( message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}" f" {shape_list(attention_mask)}"
), ),
) )
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
...@@ -884,17 +878,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer): ...@@ -884,17 +878,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1) attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(layer_head_mask),
if tf.executing_eagerly(): [self.num_heads],
tf.debugging.assert_equal( message=(
shape_list(layer_head_mask), f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
[self.num_heads], f" {shape_list(layer_head_mask)}"
message=( ),
f"Head mask for a single layer should be of size {(self.num_heads)}, but is" )
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len) attn_weights, (bsz, self.num_heads, tgt_len, src_len)
...@@ -904,17 +895,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer): ...@@ -904,17 +895,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training) attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_output),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, self.head_dim],
tf.debugging.assert_equal( message=(
shape_list(attn_output), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
[bsz * self.num_heads, tgt_len, self.head_dim], f" {shape_list(attn_output)}"
message=( ),
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" )
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose( attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
......
...@@ -239,31 +239,25 @@ class TFXGLMAttention(tf.keras.layers.Layer): ...@@ -239,31 +239,25 @@ class TFXGLMAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1] src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True) attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_weights),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_weights), shape_list(attention_mask),
[bsz * self.num_heads, tgt_len, src_len], [bsz, 1, tgt_len, src_len],
message=( message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}" f" {shape_list(attention_mask)}"
), ),
) )
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
...@@ -271,17 +265,14 @@ class TFXGLMAttention(tf.keras.layers.Layer): ...@@ -271,17 +265,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1) attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(layer_head_mask),
if tf.executing_eagerly(): [self.num_heads],
tf.debugging.assert_equal( message=(
shape_list(layer_head_mask), f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
[self.num_heads], f" {shape_list(layer_head_mask)}"
message=( ),
f"Head mask for a single layer should be of size {(self.num_heads)}, but is" )
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len) attn_weights, (bsz, self.num_heads, tgt_len, src_len)
...@@ -291,17 +282,14 @@ class TFXGLMAttention(tf.keras.layers.Layer): ...@@ -291,17 +282,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training) attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_output),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, self.head_dim],
tf.debugging.assert_equal( message=(
shape_list(attn_output), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
[bsz * self.num_heads, tgt_len, self.head_dim], f" {shape_list(attn_output)}"
message=( ),
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" )
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose( attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
...@@ -568,10 +556,8 @@ class TFXGLMMainLayer(tf.keras.layers.Layer): ...@@ -568,10 +556,8 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly(): if attn_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_mask)[0], shape_list(attn_mask)[0],
len(self.layers), len(self.layers),
......
...@@ -105,9 +105,9 @@ def get_masks(slen, lengths, causal, padding_mask=None): ...@@ -105,9 +105,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
# sanity check # sanity check
# assert shape_list(mask) == [bs, slen] # assert shape_list(mask) == [bs, slen]
if tf.executing_eagerly(): tf.debugging.assert_equal(shape_list(mask), [bs, slen])
tf.debugging.assert_equal(shape_list(mask), [bs, slen]) if causal:
assert causal is False or shape_list(attn_mask) == [bs, slen, slen] tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
return mask, attn_mask return mask, attn_mask
...@@ -384,10 +384,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -384,10 +384,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# check inputs # check inputs
# assert shape_list(lengths)[0] == bs # assert shape_list(lengths)[0] == bs
if tf.executing_eagerly(): tf.debugging.assert_equal(
tf.debugging.assert_equal( shape_list(lengths)[0], bs
shape_list(lengths)[0], bs ), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen # assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None) # assert (src_enc is None) == (src_len is None)
...@@ -405,15 +404,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -405,15 +404,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
position_ids = tf.expand_dims(tf.range(slen), axis=0) position_ids = tf.expand_dims(tf.range(slen), axis=0)
position_ids = tf.tile(position_ids, (bs, 1)) position_ids = tf.tile(position_ids, (bs, 1))
if tf.executing_eagerly(): # assert shape_list(position_ids) == [bs, slen] # (slen, bs)
# assert shape_list(position_ids) == [bs, slen] # (slen, bs) tf.debugging.assert_equal(
tf.debugging.assert_equal( shape_list(position_ids), [bs, slen]
shape_list(position_ids), [bs, slen] ), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched" # position_ids = position_ids.transpose(0, 1)
# position_ids = position_ids.transpose(0, 1)
# langs # langs
if langs is not None and tf.executing_eagerly(): if langs is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs) # assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(langs), [bs, slen] shape_list(langs), [bs, slen]
......
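In the XLM `get_masks` hunk above, the old Python `assert causal is False or ...` is replaced by a `tf.debugging.assert_equal` behind a plain `if causal:`. A hedged sketch of the difference, using `tf.shape` instead of the library's `shape_list` for brevity: a bare Python `assert` is only evaluated while the function is being traced and breaks down with dynamic shapes, whereas the `tf.debugging` check behaves the same eagerly and in graph mode.

```python
# Minimal sketch, not from the diff; the actual code compares shape_list(...)
# results rather than tf.shape.
import tensorflow as tf


@tf.function
def check_attn_mask(attn_mask, bs, slen, causal):
    # old style, effectively eager-only:
    # assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
    if causal:  # plain Python branch, resolved when the function is traced
        tf.debugging.assert_equal(tf.shape(attn_mask), [bs, slen, slen])
    return attn_mask


check_attn_mask(tf.ones((2, 4, 4)), bs=2, slen=4, causal=True)    # passes
# check_attn_mask(tf.ones((2, 4, 3)), bs=2, slen=4, causal=True)  # raises InvalidArgumentError
```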
...@@ -1693,13 +1693,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to ...@@ -1693,13 +1693,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
) )
if tf.executing_eagerly(): # "Verify that `labels` has only positive values and -100"
# "Verify that `labels` has only positive values and -100" assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op # Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]): with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids) shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids return shifted_input_ids
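Because an assertion op that nothing consumes can be pruned from a graph, the result is tied to it with tf.control_dependencies and tf.identity. A minimal sketch of that idiom (hypothetical function name, toy input):

    import tensorflow as tf

    @tf.function
    def validate_labels(shifted_input_ids):
        # After replacing -100 with the pad token, every value must be >= 0.
        assert_gte0 = tf.debugging.assert_greater_equal(
            shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype)
        )
        # Tie the assert op to the returned tensor so graph pruning keeps it.
        with tf.control_dependencies([assert_gte0]):
            return tf.identity(shifted_input_ids)

    validate_labels(tf.constant([[1, 2, 0]]))  # passes; a negative value would raise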
...@@ -1837,24 +1836,18 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer): ...@@ -1837,24 +1836,18 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1] src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True) attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_weights),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, src_len],
tf.debugging.assert_equal( message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
shape_list(attn_weights), )
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
)
if attention_mask is not None: if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attention_mask),
if tf.executing_eagerly(): [bsz, 1, tgt_len, src_len],
tf.debugging.assert_equal( message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
shape_list(attention_mask), )
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
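The surrounding reshape round-trip exists so the (bsz, 1, tgt_len, src_len) attention mask can broadcast over the head dimension. A toy illustration with made-up sizes:

    import tensorflow as tf

    # Illustrative sizes only: flattened attention scores are unfolded so the
    # (bsz, 1, tgt_len, src_len) mask broadcasts over the heads, then refolded.
    bsz, num_heads, tgt_len, src_len = 2, 4, 3, 3
    attn_weights = tf.zeros((bsz * num_heads, tgt_len, src_len))
    attention_mask = tf.zeros((bsz, 1, tgt_len, src_len))

    attn_weights = tf.reshape(attn_weights, (bsz, num_heads, tgt_len, src_len)) + attention_mask
    attn_weights = tf.reshape(attn_weights, (bsz * num_heads, tgt_len, src_len))
    print(attn_weights.shape)  # (8, 3, 3)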
...@@ -1862,14 +1855,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer): ...@@ -1862,14 +1855,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1) attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None: if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(layer_head_mask),
if tf.executing_eagerly(): [self.num_heads],
tf.debugging.assert_equal( message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
shape_list(layer_head_mask), )
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len) attn_weights, (bsz, self.num_heads, tgt_len, src_len)
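The head mask applied after the softmax relies on the same kind of broadcasting: a (num_heads,) vector is reshaped to (1, num_heads, 1, 1). A toy illustration with made-up sizes, zeroing out one head:

    import tensorflow as tf

    # Illustrative sizes only: the per-head mask broadcasts over the batch and
    # both sequence axes once reshaped to (1, num_heads, 1, 1).
    bsz, num_heads, tgt_len, src_len = 2, 4, 3, 3
    layer_head_mask = tf.constant([1.0, 0.0, 1.0, 1.0])  # drop head 1
    attn_weights = tf.ones((bsz * num_heads, tgt_len, src_len))

    masked = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
        attn_weights, (bsz, num_heads, tgt_len, src_len)
    )
    attn_weights = tf.reshape(masked, (bsz * num_heads, tgt_len, src_len))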
...@@ -1880,14 +1870,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer): ...@@ -1880,14 +1870,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(attn_output),
if tf.executing_eagerly(): [bsz * self.num_heads, tgt_len, self.head_dim],
tf.debugging.assert_equal( message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
shape_list(attn_output), )
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
)
attn_output = tf.transpose( attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
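After the assertion, the per-head output is folded back into a single hidden dimension. A toy walk-through of the shapes involved, with made-up sizes:

    import tensorflow as tf

    # Illustrative sizes only: (bsz * num_heads, tgt_len, head_dim) is split back
    # into heads, the head axis is moved next to head_dim, and both are merged.
    bsz, num_heads, tgt_len, head_dim = 2, 4, 3, 8
    attn_output = tf.zeros((bsz * num_heads, tgt_len, head_dim))

    attn_output = tf.transpose(
        tf.reshape(attn_output, (bsz, num_heads, tgt_len, head_dim)), (0, 2, 1, 3)
    )  # (bsz, tgt_len, num_heads, head_dim)
    attn_output = tf.reshape(attn_output, (bsz, tgt_len, num_heads * head_dim))  # (2, 3, 32)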
...@@ -1929,14 +1916,11 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer): ...@@ -1929,14 +1916,11 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
# The tf.debugging asserts are not compliant with XLA then they tf.debugging.assert_equal(
# have to be disabled in other modes than eager. shape_list(hidden_states),
if tf.executing_eagerly(): shape_list(residual),
tf.debugging.assert_equal( message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
shape_list(hidden_states), )
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -2332,9 +2316,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): ...@@ -2332,9 +2316,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they if head_mask is not None:
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(head_mask)[0], shape_list(head_mask)[0],
len(self.layers), len(self.layers),
...@@ -2529,10 +2511,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer): ...@@ -2529,10 +2511,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly(): if attn_mask is not None:
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_mask)[0], shape_list(attn_mask)[0],
len(self.layers), len(self.layers),
......
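The decoder-side check above loops over both head masks so the failure message can name the offending argument. A minimal sketch of that pattern outside the model class (hypothetical function name and message text, illustrative mask shape):

    import tensorflow as tf

    def check_head_masks(head_mask, cross_attn_head_mask, num_layers):
        # Each mask, if provided, must carry one entry per decoder layer; the
        # mask's name is baked into the assertion message.
        for attn_mask_name, attn_mask in [
            ("head_mask", head_mask),
            ("cross_attn_head_mask", cross_attn_head_mask),
        ]:
            if attn_mask is not None:
                tf.debugging.assert_equal(
                    tf.shape(attn_mask)[0],
                    num_layers,
                    message=f"The {attn_mask_name} should cover {num_layers} layers",
                )

    check_head_masks(tf.ones((6, 16)), None, num_layers=6)  # passes for a 6-layer stack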