Unverified commit 64845307 authored by fxmarty, committed by GitHub

Remove unnecessary unsqueeze - squeeze in rotary positional embedding (#26162)

* remove unnecessary unsqueeze-squeeze in llama

* correct other models

* fix

* revert gpt_neox_japanese

* fix copies

* fix test
Parent commit: 65aabafe
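For context, the change stores the rotary cos/sin caches as 2-D [seq_len, dim] tensors instead of [1, 1, seq_len, dim], so apply_rotary_pos_emb no longer has to squeeze two leading singleton dimensions before gathering by position_ids. A minimal standalone sketch (not the transformers code itself; the shapes and dummy values are illustrative) showing that both layouts yield the same broadcasted tensor:

# Standalone sketch (illustrative only; seq_len, dim and position_ids are made up here).
import torch

seq_len, dim = 8, 4
position_ids = torch.arange(seq_len).unsqueeze(0)  # [bs=1, seq_len]

inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq_len).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, dim]

# Old layout: the cache carried two leading singleton dimensions that were later squeezed away.
cos_old_cache = emb.cos()[None, None, :, :]                         # [1, 1, seq_len, dim]
cos_old = cos_old_cache[:, :, :seq_len, ...].squeeze(1).squeeze(0)  # back to [seq_len, dim]

# New layout: the cache is stored directly as [seq_len, dim].
cos_new = emb.cos()[:seq_len]

# Both paths gather by position_ids and broadcast identically over the head dimension.
assert torch.equal(cos_old[position_ids].unsqueeze(1), cos_new[position_ids].unsqueeze(1))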
@@ -118,8 +118,8 @@ class OpenLlamaRotaryEmbedding(nn.Module):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
@@ -127,8 +127,8 @@ class OpenLlamaRotaryEmbedding(nn.Module):
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
@@ -148,8 +148,8 @@ class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama
@@ -175,8 +175,8 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def rotate_half(x):
@@ -188,11 +188,8 @@ def rotate_half(x):
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
...
@@ -115,8 +115,8 @@ class FalconRotaryEmbedding(nn.Module):
        if dtype in [torch.float16, torch.bfloat16]:
            emb = emb.float()
-        self.cos_cached = emb.cos()[None, :, :]
-        self.sin_cached = emb.sin()[None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()
        self.cos_cached = self.cos_cached.type(dtype)
        self.sin_cached = self.sin_cached.type(dtype)
@@ -133,8 +133,8 @@ class FalconRotaryEmbedding(nn.Module):
        self.sin_cached = self.sin_cached.to(device)
        # Gather cos, sin at the designated position ids
-        cos = self.cos_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
-        sin = self.sin_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
+        cos = self.cos_cached[position_ids]  # [bs, seq_len, dim]
+        sin = self.sin_cached[position_ids]  # [bs, seq_len, dim]
        return cos, sin
    def forward(self, query, key, past_key_values_length, position_ids):
@@ -181,8 +181,8 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
        if dtype in [torch.float16, torch.bfloat16]:
            emb = emb.float()
-        self.cos_cached = emb.cos()[None, :, :]
-        self.sin_cached = emb.sin()[None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()
        self.cos_cached = self.cos_cached.type(dtype)
        self.sin_cached = self.sin_cached.type(dtype)
@@ -215,8 +215,8 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
        if dtype in [torch.float16, torch.bfloat16]:
            emb = emb.float()
-        self.cos_cached = emb.cos()[None, :, :]
-        self.sin_cached = emb.sin()[None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()
        self.cos_cached = self.cos_cached.type(dtype)
        self.sin_cached = self.sin_cached.type(dtype)
...
@@ -309,8 +309,8 @@ class GPTNeoXRotaryEmbedding(nn.Module):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
@@ -318,8 +318,8 @@ class GPTNeoXRotaryEmbedding(nn.Module):
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
@@ -339,8 +339,8 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX
@@ -366,8 +366,8 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def rotate_half(x):
@@ -379,11 +379,8 @@ def rotate_half(x):
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
...
@@ -261,8 +261,8 @@ class RotaryEmbedding(nn.Module):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
@@ -270,8 +270,8 @@ class RotaryEmbedding(nn.Module):
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
...
@@ -523,8 +523,8 @@ class IdeficsEmbedding(torch.nn.Module):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
@@ -532,8 +532,8 @@ class IdeficsEmbedding(torch.nn.Module):
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
@@ -546,11 +546,8 @@ def rotate_half(x):
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
...
@@ -138,8 +138,8 @@ class LlamaRotaryEmbedding(nn.Module):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
@@ -147,8 +147,8 @@ class LlamaRotaryEmbedding(nn.Module):
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
@@ -167,8 +167,8 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
@@ -193,8 +193,8 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def rotate_half(x):
@@ -204,12 +204,10 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
...
@@ -149,8 +149,8 @@ class MistralRotaryEmbedding(nn.Module):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
@@ -158,8 +158,8 @@ class MistralRotaryEmbedding(nn.Module):
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
@@ -173,11 +173,8 @@ def rotate_half(x):
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
...
@@ -94,8 +94,8 @@ class PersimmonRotaryEmbedding(nn.Module):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
@@ -103,8 +103,8 @@ class PersimmonRotaryEmbedding(nn.Module):
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
@@ -124,8 +124,8 @@ class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
@@ -151,8 +151,8 @@ class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -165,11 +165,8 @@ def rotate_half(x):
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
...
@@ -430,7 +430,8 @@ class MistralIntegrationTest(unittest.TestCase):
    def test_model_7b_logits(self):
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
        model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")
-        out = model(torch.tensor([input_ids])).logits
+        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+        out = model(input_ids).logits.cpu()
        # Expected mean on dim = -1
        EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
...
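The test update above accounts for device_map="auto" potentially sharding the model across devices: inputs are moved to the device holding the embedding weights, and the logits are brought back to CPU before comparing against the CPU expectation tensor. A sketch of that pattern (the model name and token ids come from the test; the rest is illustrative, not the test file itself):

# Sketch of the device-handling pattern used in the updated test; illustrative only.
import torch
from transformers import MistralForCausalLM

model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")

input_ids = torch.tensor([[1, 306, 4658, 278, 6593, 310, 2834, 338]])
# With device_map="auto" the embedding layer may live on an accelerator, so move the inputs to it.
input_ids = input_ids.to(model.model.embed_tokens.weight.device)

# Bring the logits back to CPU so they can be compared against CPU-side expected values.
logits = model(input_ids).logits.cpu()
print(logits.mean(-1))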