Unverified Commit 689b1abb authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

fix EleutherAI/gpt-neox-20b does not work in tgi (#2346)


Signed-off-by: default avatarWang, Yi A <yi.a.wang@intel.com>
parent 82d19d77
......@@ -153,8 +153,16 @@ class FlashNeoxAttention(torch.nn.Module):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
# Compute rotary embeddings on rotary_ndims
query_rot = qkv[:, 0][..., : self.rotary_dim]
query_pass = qkv[:, 0][..., self.rotary_dim :]
key_rot = qkv[:, 1][..., : self.rotary_dim]
key_pass = qkv[:, 1][..., self.rotary_dim :]
# Inplace rotary
self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
self.rotary_emb(query_rot, key_rot, cos, sin)
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment