Commit 841a2313 authored by s4rduk4r
Browse files

Fix apply_rotary_emb() to have both tensors on the same device

parent ffaaa259
...@@ -36,7 +36,7 @@ def apply_rotary_emb( ...@@ -36,7 +36,7 @@ def apply_rotary_emb(
xk_ = torch.view_as_complex( xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous() xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
) )
freqs_cis = reshape_for_broadcast(freqs_cis, xq_) freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3) xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk) return xq_out.type_as(xq), xk_out.type_as(xk)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment