"src/vscode:/vscode.git/clone" did not exist on "d1222064669a758c476014fb7a09e24a0c907222"
Unverified commit c5933c9c, authored by Patrick von Platen, committed by GitHub
Browse files

[Bug fix] Fix batch size attention head size mismatch (#3214)

parent 91a2a80e
......@@ -86,8 +86,10 @@ class AttentionBlock(nn.Module):
head_size = self.num_heads
if unmerge_head_and_batch:
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
batch_head_size, seq_len, dim = tensor.shape
batch_size = batch_head_size // head_size
tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
else:
batch_size, _, seq_len, dim = tensor.shape
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment