"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b7345d22d0b59ccfda8df840a918af33cf95a189"
Unverified commit 626c1b8a, authored by fpgaminer, committed by GitHub
Browse files

improve(llama): Faster apply_rotary_pos_emb (#22785)

parent abbc96a2
@@ -131,10 +131,11 @@ def rotate_half(x):
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
-    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
-    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+    cos = cos.squeeze((0, 1))  # [seq_len, dim]
+    sin = sin.squeeze((0, 1))  # [seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment