"src/vscode:/vscode.git/clone" did not exist on "4903d3cc9deeece7b28024971d1279f4f085d83e"
Unverified commit 9179605e, authored by Mario928 and committed by GitHub
Browse files

Fix: Replace view() with reshape() in neox_modeling.py to resolve RuntimeError (#1155)

parent 7402a355
...@@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module): ...@@ -283,10 +283,10 @@ class GPTNeoXAttention(nn.Module):
batch_size, num_attention_heads, query_length, attn_head_size = query.size() batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2) key_length = key.size(-2)
query = query.view( query = query.reshape(
batch_size * num_attention_heads, query_length, attn_head_size batch_size * num_attention_heads, query_length, attn_head_size
) )
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros( attn_scores = torch.zeros(
1, 1,
dtype=query.dtype, dtype=query.dtype,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment