message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size*2+1}), but is of size {shape_list(attn_scores)}",
message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size*2+1}), but is of size {shape_list(attn_scores)}",
)
# compute global attn indices required throughout forward fn
(
...
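# A hedged sketch (not the library's actual helper): one way to derive the global
# attention bookkeeping that the comment above refers to, starting from a boolean
# mask `is_index_global_attn` of shape (batch_size, seq_len). The function name and
# return values below are illustrative assumptions, not the upstream API.
import tensorflow as tf

def sketch_global_attn_indices(is_index_global_attn):
    # number of global tokens per batch row, and the maximum across the batch
    num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
    max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)
    # (row, position) pairs of every token that receives global attention
    is_index_global_attn_nonzero = tf.where(is_index_global_attn)
    return max_num_global_attn_indices, is_index_global_attn_nonzero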
@@ -803,16 +804,18 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size,frame_size,num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
[batch_size,num_output_chunks,frame_size],
)
message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size,frame_size,num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
)
chunked_hidden_states = tf.reshape(
chunked_hidden_states,
...
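# A hedged sketch of the chunking step the assertion above guards: flatten the
# hidden states, slice them into overlapping frames of 2 * window_overlap tokens
# with a hop of window_overlap, check the framed shape, then restore the hidden
# dimension. Variable names mirror the diff; the standalone function is illustrative
# and assumes seq_len is already padded to a multiple of 2 * window_overlap.
import tensorflow as tf

def sketch_chunk(hidden_states, window_overlap):
    batch_size, seq_len, hidden_dim = hidden_states.shape
    num_output_chunks = 2 * (seq_len // (2 * window_overlap)) - 1
    frame_hop_size = window_overlap * hidden_dim
    frame_size = 2 * frame_hop_size
    flat = tf.reshape(hidden_states, (batch_size, seq_len * hidden_dim))
    chunked_hidden_states = tf.signal.frame(flat, frame_size, frame_hop_size)
    # this is the shape the assert_equal message above describes
    tf.debugging.assert_equal(tf.shape(chunked_hidden_states), [batch_size, num_output_chunks, frame_size])
    return tf.reshape(
        chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim)
    )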
@@ -1175,7 +1190,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
message=f"global_attn_scores have the wrong size. Size should be {(batch_size*self.num_heads,max_num_global_attn_indices,seq_len)}, but is {shape_list(global_attn_scores)}.",
message=f"global_attn_scores have the wrong size. Size should be {(batch_size*self.num_heads,max_num_global_attn_indices,seq_len)}, but is {shape_list(global_attn_scores)}.",
)
global_attn_scores = tf.reshape(
global_attn_scores,
...
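# A hedged sketch of the layout the global_attn_scores check above refers to: the
# global query/key matmul runs with the head axis folded into the batch axis,
# producing (batch_size * num_heads, max_num_global_attn_indices, seq_len), and the
# reshape splits the heads back out. All shapes and names here are illustrative.
import tensorflow as tf

batch_size, num_heads, max_num_global_attn_indices, seq_len, head_dim = 2, 12, 3, 512, 64
global_query = tf.random.normal((batch_size * num_heads, max_num_global_attn_indices, head_dim))
key = tf.random.normal((batch_size * num_heads, seq_len, head_dim))
global_attn_scores = tf.matmul(global_query, key, transpose_b=True)
# (batch_size * num_heads, max_num_global_attn_indices, seq_len), as the message above expects
global_attn_scores = tf.reshape(
    global_attn_scores, (batch_size, num_heads, max_num_global_attn_indices, seq_len)
)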
@@ -1346,6 +1366,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size*self.num_heads,max_num_global_attn_indices,self.head_dim)}, but is {shape_list(global_attn_output)}.",
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size*self.num_heads,max_num_global_attn_indices,self.head_dim)}, but is {shape_list(global_attn_output)}.",
)
global_attn_output = tf.reshape(
global_attn_output,
...
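# A hedged sketch of the matching output step: attention probabilities in the merged
# (batch_size * num_heads, ...) layout are applied to the value vectors, giving the
# (batch_size * num_heads, max_num_global_attn_indices, head_dim) shape named in the
# message above, before being reshaped back to a per-head layout. Illustrative shapes only.
import tensorflow as tf

batch_size, num_heads, max_num_global_attn_indices, seq_len, head_dim = 2, 12, 3, 512, 64
global_attn_probs = tf.nn.softmax(
    tf.random.normal((batch_size * num_heads, max_num_global_attn_indices, seq_len)), axis=-1
)
value = tf.random.normal((batch_size * num_heads, seq_len, head_dim))
global_attn_output = tf.matmul(global_attn_probs, value)
global_attn_output = tf.reshape(
    global_attn_output, (batch_size, num_heads, max_num_global_attn_indices, head_dim)
)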
@@ -2230,6 +2253,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
logger.info("Initializing global attention on question tokens...")
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
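# A hedged sketch of what the comment above describes: give every token before the
# first `config.sep_token_id` (i.e. the question part of a QA input) global
# attention. A cumulative count of separator tokens is still zero on those
# positions, so they get a 1 in the mask. This is an illustration, not the model's
# internal helper.
import tensorflow as tf

def sketch_question_global_attention_mask(input_ids, sep_token_id):
    sep_seen = tf.cumsum(tf.cast(input_ids == sep_token_id, tf.int32), axis=1)
    return tf.cast(sep_seen == 0, tf.int32)

# e.g. sketch_question_global_attention_mask(tf.constant([[101, 2054, 2003, 102, 3793, 102]]), 102)
# -> [[1, 1, 1, 0, 0, 0]]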