Unverified Commit e5dd7432 authored by fxmarty, committed by GitHub

Remove unnecessary aten::to overhead in llama (#24203)

* fix dtype init

* fix copies

* fix fixcopies mess

* edit forward as well

* copy
parent 4fe9716a
@@ -100,8 +100,9 @@ class LlamaRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        dtype = torch.get_default_dtype()
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -112,8 +113,8 @@ class LlamaRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
         return (
             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
...
@@ -111,8 +111,9 @@ class OpenLlamaRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        dtype = torch.get_default_dtype()
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -123,8 +124,8 @@ class OpenLlamaRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
         return (
             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
...
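
Note: the point of the patch is that the cos/sin tables are now cast to the target dtype once, when the cache is built, instead of on every forward call; the remaining .to(dtype=x.dtype) in forward then returns its input unchanged rather than launching a cast kernel each decoding step. Below is a minimal standalone sketch of that caching pattern for illustration only; it is not the actual transformers class, and the names ToyRotaryEmbedding, dim, and _set_cos_sin_cache are invented for this sketch.

import torch


class ToyRotaryEmbedding(torch.nn.Module):
    # Illustrative re-creation of the caching pattern in the diff above.
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build the cache once in the current default dtype (e.g. float32, or
        # whatever torch.set_default_dtype selected).
        self._set_cos_sin_cache(max_position_embeddings, dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, dtype, device=None):
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Cast the cached tables here, at build/resize time, not on every forward.
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            # Rebuild directly in the runtime dtype/device when the sequence grows.
            self._set_cos_sin_cache(seq_len, dtype=x.dtype, device=x.device)
        # The caches already match x.dtype, so these .to() calls do not copy.
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


# Example use: after casting the module to fp16, the returned slices are already
# fp16 and no per-call cast is needed.
rope = ToyRotaryEmbedding(dim=128).to(torch.float16)
q = torch.randn(1, 32, 16, 128, dtype=torch.float16)
cos, sin = rope(q, seq_len=16)  # cos.dtype == torch.float16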