Unverified commit cf601b90, authored by Guangyuan Ma and committed by GitHub

Fix unnecessary move of tensors from CPU to GPU in LlamaRotaryEmbedding (#22234)

Register cos_cached and sin_cached as non-persistent buffers so they are moved to the model's device together with the module, instead of being transferred from CPU to GPU on every forward call; forward now only casts the cached values to the input dtype.
parent bec07561
@@ -99,8 +99,8 @@ class LlamaRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()[None, None, :, :]
-        self.sin_cached = emb.sin()[None, None, :, :]
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -111,11 +111,11 @@ class LlamaRotaryEmbedding(torch.nn.Module):
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             # Different from paper, but it uses a different permutation in order to obtain the same calculation
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.cos_cached = emb.cos()[None, None, :, :].to(dtype=x.dtype)
-            self.sin_cached = emb.sin()[None, None, :, :].to(dtype=x.dtype)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device),
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         )
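
Why the change works, shown as a minimal standalone sketch (not part of the patch; the CacheDemo class and attribute names are invented for illustration): a tensor stored as a plain Python attribute is invisible to nn.Module, so model.to("cuda") leaves it on the CPU and every forward call has to copy it to the GPU, whereas a buffer registered with register_buffer(..., persistent=False) moves with the module and is still kept out of state_dict(), so checkpoint keys do not change.

import torch
from torch import nn

class CacheDemo(nn.Module):
    def __init__(self):
        super().__init__()
        emb = torch.arange(8, dtype=torch.float32)
        # Plain attribute: not tracked by nn.Module, so .to()/.cuda() leaves it on the CPU
        # and forward() would need an explicit .to(x.device) copy on every call.
        self.plain_cache = emb.cos()
        # Non-persistent buffer: moved together with the module by .to()/.cuda(),
        # but excluded from state_dict(), so checkpoint keys stay the same.
        self.register_buffer("buffer_cache", emb.sin(), persistent=False)

demo = CacheDemo()
if torch.cuda.is_available():
    demo.to("cuda")
    print(demo.plain_cache.device)   # cpu  -> would force a copy inside forward
    print(demo.buffer_cache.device)  # cuda:0 -> already where the activations live
print("buffer_cache" in demo.state_dict())  # False: nothing new is saved or loaded

With the caches registered this way, the return statement in forward only needs the dtype cast that the diff keeps.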