Unverified Commit f8c43c94 authored by Kunhao ZHENG, committed by GitHub

Fix squeeze into torch 1.x compatible form in llama model (#22808)

fix-squeeze-tuple
parent 5269718c
@@ -132,8 +132,8 @@ def rotate_half(x):
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze((0, 1))  # [seq_len, dim]
-    sin = sin.squeeze((0, 1))  # [seq_len, dim]
+    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
     q_embed = (q * cos) + (rotate_half(q) * sin)
...
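
Context on the fix: Tensor.squeeze only accepts a tuple of dimensions starting with PyTorch 2.0; on torch 1.x the dim argument must be a single int, so cos.squeeze((0, 1)) raises a TypeError there. Chaining single-dimension squeezes is equivalent and runs on both versions. A minimal sketch of the behavior (shapes follow the comments in the diff; the concrete sizes 16 and 64 are illustrative, not from the source):

    import torch

    # The rotary cos/sin caches have shape [1, 1, seq_len, dim]; the two
    # leading singleton dimensions are the ones being squeezed away.
    cos = torch.randn(1, 1, 16, 64)

    # torch >= 2.0 only:
    #     cos.squeeze((0, 1))       # -> [seq_len, dim]; TypeError on torch 1.x
    # Portable form: squeeze dim 1 first, then dim 0, so the index of the
    # remaining singleton dimension does not shift between the two calls.
    squeezed = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    assert squeezed.shape == (16, 64)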