message=f"attn_probs should be of size ({batch_size}, {seq_len}, {num_heads}, {self._one_sided_attn_window_size*2+1}), but is of size {get_shape_list(attn_scores)}",
[batch_size,seq_len,self._num_heads,
self._one_sided_attn_window_size*2+1],
message=f"attn_probs should be of size "
f"({batch_size}, {seq_len}, {num_heads}, "
f"{self._one_sided_attn_window_size*2+1}),"
f" but is of size {get_shape_list(attn_scores)}",
)
   # compute global attn indices required throughout forward fn
...
...
@@ -303,7 +309,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size,frame_size,num_output_chunks]}, but got {get_shape_list(chunked_hidden_states)}.",
message=f"Make sure chunking is correctly applied. `Chunked hidden "
f"states should have output dimension"
f" {[batch_size,frame_size,num_output_chunks]}, but got "
f"{get_shape_list(chunked_hidden_states)}.",
)
chunked_hidden_states=tf.reshape(
...
...
@@ -738,19 +793,25 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
message=f"global_attn_scores have the wrong size. Size should be {(batch_size*self._num_heads,max_num_global_attn_indices,seq_len)}, but is {get_shape_list(global_attn_scores)}.",
message=f"global_attn_scores have the wrong size. Size should be"
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size*self._num_heads,max_num_global_attn_indices,self._key_dim)}, but is {get_shape_list(global_attn_output)}.",