"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "5fd4fd2870b45a41c2d0cd4d2b6b4ca8434c4bd2"
Unverified Commit 08b7e4a2 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

fix(server): fix flash neox rotary embeddings (#150)

parent 610bb1f9
......@@ -319,12 +319,12 @@ class FlashNeoxAttention(torch.nn.Module):
layer_past[...] = qkv_rot[:, 1:]
# output
attn_output = torch.empty_like(qkv[:, 0])
attn_output = torch.empty_like(qkv_rot[:, 0])
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
qkv_rot[:, 0],
qkv_rot[:, 1],
qkv_rot[:, 2],
attn_output,
cu_seqlens,
cu_seqlens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment