Unverified Commit 9179605e authored by Mario928's avatar Mario928 Committed by GitHub
Browse files

Fix: Replace view() with reshape() in neox_modeling.py to resolve RuntimeError (#1155)

parent 7402a355
......@@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module):
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
query = query.view(
query = query.reshape(
batch_size * num_attention_heads, query_length, attn_head_size
)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
1,
dtype=query.dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment