message=f"attn_probs should be of size ({batch_size}, {seq_len}, {num_heads}, {self._one_sided_attn_window_size*2+1}), but is of size {shape_list(attn_scores)}",
message=f"attn_probs should be of size ({batch_size}, {seq_len}, {num_heads}, {self._one_sided_attn_window_size*2+1}), but is of size {get_shape_list(attn_scores)}",
)
)
# compute global attn indices required through out forward fn
# compute global attn indices required through out forward fn
...
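The recurring change in this file is the swap from the shape_list helper to get_shape_list (presumably the project's own shape utility). As a point of reference only, here is a minimal self-contained sketch of the usual contract such a helper follows: static dimensions come back as Python ints, dimensions unknown at trace time come back as scalar tensors. This stand-in is illustrative, not the project's actual implementation.

    import tensorflow as tf

    def get_shape_list(tensor):
      # Illustrative stand-in: static dims as Python ints, unknown dims as
      # scalar tensors taken from tf.shape at runtime.
      static = tensor.shape.as_list()
      dynamic = tf.shape(tensor)
      return [dim if dim is not None else dynamic[i]
              for i, dim in enumerate(static)]

    @tf.function(input_signature=[tf.TensorSpec([None, None, 64], tf.float32)])
    def demo(x):
      shape = get_shape_list(x)  # [<Tensor>, <Tensor>, 64]
      tf.debugging.assert_equal(tf.shape(x)[-1], shape[-1])
      return x

    demo(tf.zeros([2, 8, 64]))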
@@ -356,7 +337,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
     # define frame size and frame stride (similar to convolution)
...
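The comment above likens the frame size and frame stride to a convolution: the sequence is cut into overlapping chunks of 2 * window_overlap tokens with a stride of window_overlap tokens. A rough self-contained sketch of that idea, with made-up sizes (the real layer derives them from its attention window and hidden size):

    import tensorflow as tf

    # Made-up sizes for illustration only.
    batch_size, seq_len, hidden_dim, window_overlap = 2, 8, 4, 2

    hidden_states = tf.random.normal((batch_size, seq_len, hidden_dim))

    # Overlapping chunks of 2 * window_overlap tokens, stride window_overlap,
    # framed over the flattened (seq_len * hidden_dim) axis.
    frame_size = 2 * window_overlap * hidden_dim
    frame_stride = window_overlap * hidden_dim
    num_output_chunks = 2 * (seq_len // (2 * window_overlap)) - 1

    flat = tf.reshape(hidden_states, (batch_size, seq_len * hidden_dim))
    chunked = tf.signal.frame(flat, frame_size, frame_stride)
    print(chunked.shape)  # (batch_size, num_output_chunks, frame_size) == (2, 3, 16)

That (batch_size, num_output_chunks, frame_size) layout is exactly what the eager-mode assertion in the next hunk checks.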
@@ -735,9 +716,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
     if tf.executing_eagerly():
       tf.debugging.assert_equal(
-          shape_list(chunked_hidden_states),
+          get_shape_list(chunked_hidden_states),
           [batch_size, num_output_chunks, frame_size],
-          message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
+          message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {get_shape_list(chunked_hidden_states)}.",
       )
     chunked_hidden_states = tf.reshape(
...
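The tf.reshape call above is cut off by the excerpt's ellipsis. The sketch below continues it under an assumed target shape (splitting each flat frame back into per-token rows), reusing the toy sizes from the chunking sketch; the reshape target is an assumption, not code from the patch.

    import tensorflow as tf

    # Toy sizes matching the chunking sketch above.
    batch_size, seq_len, hidden_dim, window_overlap = 2, 8, 4, 2
    num_output_chunks = 2 * (seq_len // (2 * window_overlap)) - 1
    frame_size = 2 * window_overlap * hidden_dim

    chunked_hidden_states = tf.zeros((batch_size, num_output_chunks, frame_size))

    # Same eager-only guard idiom as the hunk above.
    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          tf.shape(chunked_hidden_states),
          [batch_size, num_output_chunks, frame_size])

    # Assumed continuation: split each flat frame into (tokens_per_chunk, hidden_dim).
    chunked_hidden_states = tf.reshape(
        chunked_hidden_states,
        (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim))
    print(chunked_hidden_states.shape)  # (2, 3, 4, 4)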
@@ -752,7 +733,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
"""compute global attn indices required throughout forward pass"""
"""compute global attn indices required throughout forward pass"""
# All global attention size are fixed through global_attention_size
# All global attention size are fixed through global_attention_size
message=f"global_attn_scores have the wrong size. Size should be {(batch_size*self._num_heads,max_num_global_attn_indices,seq_len)}, but is {shape_list(global_attn_scores)}.",
message=f"global_attn_scores have the wrong size. Size should be {(batch_size*self._num_heads,max_num_global_attn_indices,seq_len)}, but is {get_shape_list(global_attn_scores)}.",
)
)
global_attn_scores=tf.reshape(
global_attn_scores=tf.reshape(
...
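The two global-attention hunks follow the same pattern: scores (and, below, outputs) are produced in a flattened (batch_size * num_heads, ...) layout, checked eagerly, then reshaped so per-example masking can be applied. A small self-contained sketch of that shape round-trip, with made-up sizes:

    import tensorflow as tf

    # Made-up sizes for illustration only.
    batch_size, num_heads, max_num_global_attn_indices, seq_len = 2, 4, 3, 8

    global_attn_scores = tf.random.normal(
        (batch_size * num_heads, max_num_global_attn_indices, seq_len))

    # Eager-only check on the flattened layout, mirroring the diff's assertion.
    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          tf.shape(global_attn_scores),
          [batch_size * num_heads, max_num_global_attn_indices, seq_len])

    # Un-flatten the head dimension so masks indexed by batch can be applied.
    global_attn_scores = tf.reshape(
        global_attn_scores,
        (batch_size, num_heads, max_num_global_attn_indices, seq_len))
    print(global_attn_scores.shape)  # (2, 4, 3, 8)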
@@ -922,8 +903,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size*self._num_heads,max_num_global_attn_indices,self._key_dim)}, but is {shape_list(global_attn_output)}.",
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size*self._num_heads,max_num_global_attn_indices,self._key_dim)}, but is {get_shape_list(global_attn_output)}.",
)
)
global_attn_output=tf.reshape(
global_attn_output=tf.reshape(
...
@@ -987,7 +968,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):