Unverified commit c5933c9c authored by Patrick von Platen, committed by GitHub

[Bug fix] Fix batch size attention head size mismatch (#3214)

parent 91a2a80e
```diff
@@ -86,8 +86,10 @@ class AttentionBlock(nn.Module):
         head_size = self.num_heads
         if unmerge_head_and_batch:
-            batch_size, seq_len, dim = tensor.shape
-            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+            batch_head_size, seq_len, dim = tensor.shape
+            batch_size = batch_head_size // head_size
+            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
         else:
             batch_size, _, seq_len, dim = tensor.shape
```
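The diff's point is that the leading dimension of the incoming tensor is the batch size *times* the number of attention heads, so the true batch size must be recovered by dividing before (and separately from) the reshape; the old code reshaped correctly but left `batch_size` holding the merged batch*heads count for any code after the branch. A minimal standalone sketch of the fixed logic (using NumPy and a hypothetical `unmerge_heads` helper, not the diffusers API):

```python
import numpy as np

def unmerge_heads(tensor, head_size):
    # `tensor` has shape (batch * heads, seq_len, dim): batch and head
    # dims were previously merged into dim 0.
    batch_head_size, seq_len, dim = tensor.shape
    # Recover the true batch size first, so callers can use it afterwards
    # (this is the essence of the fix in the diff above).
    batch_size = batch_head_size // head_size
    tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
    return tensor, batch_size

# batch=2 and heads=3 merged into a leading dim of 6
merged = np.zeros((6, 4, 8))
unmerged, batch_size = unmerge_heads(merged, head_size=3)
print(unmerged.shape, batch_size)  # (2, 3, 4, 8) 2
```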