message=f"attn_probs should be of size ({batch_size}, {seq_len}, {num_heads}, {self._one_sided_attn_window_size*2+1}), but is of size {shape_list(attn_scores)}",
)
# compute global attn indices required throughout forward fn
message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size,frame_size,num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
message=f"global_attn_scores have the wrong size. Size should be {(batch_size*self._num_heads,max_num_global_attn_indices,seq_len)}, but is {shape_list(global_attn_scores)}.",
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size*self._num_heads,max_num_global_attn_indices,self._key_dim)}, but is {shape_list(global_attn_output)}.",