"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8d2fca07e85af51f50e297d14e99318c1f665a9c"
Unverified Commit cf601b90 authored by Guangyuan Ma, committed by GitHub

Fix Unnecessary move of tensors from CPU to GPU in LlamaRotaryEmbedding (#22234)

push
parent bec07561
@@ -99,8 +99,8 @@ class LlamaRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()[None, None, :, :]
-        self.sin_cached = emb.sin()[None, None, :, :]
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -111,11 +111,11 @@ class LlamaRotaryEmbedding(torch.nn.Module):
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             # Different from paper, but it uses a different permutation in order to obtain the same calculation
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.cos_cached = emb.cos()[None, None, :, :].to(dtype=x.dtype)
-            self.sin_cached = emb.sin()[None, None, :, :].to(dtype=x.dtype)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device),
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         )
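For context, here is a minimal, self-contained sketch of the pattern this commit adopts. It is not the transformers implementation; the class name RotaryCacheSketch and the parameters dim, max_positions, and base are illustrative only. The point is that caching the cos/sin tables with register_buffer(..., persistent=False) lets them follow the module through .to(device) / .half() calls without being written to the state dict, so the slices returned in forward only cast dtype instead of copying from CPU to GPU on every call.

# Minimal sketch (assumed names, not the Llama code): non-persistent buffers
# for the rotary cos/sin cache, avoiding a host-to-device copy per forward.
import torch


class RotaryCacheSketch(torch.nn.Module):
    def __init__(self, dim=64, max_positions=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_positions, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Non-persistent buffers: they move with model.to("cuda") and dtype
        # casts, but are excluded from the state dict (cheap to recompute).
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len):
        # The cached slices already live on x.device, so .to() here only
        # casts dtype instead of triggering a CPU -> GPU transfer each call.
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


if torch.cuda.is_available():
    module = RotaryCacheSketch().cuda()          # buffers move to the GPU here, once
    x = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
    cos, sin = module(x, seq_len=16)
    assert cos.device == x.device and cos.dtype == x.dtype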