Unverified commit 83f9196c, authored by Arthur, committed by GitHub

[`GPTNeoX`] Fix BC issue with 4.36 (#28602)

* fix dtype issue

* add a test

* update copied from mentions

* nits

* fixup

* fix copies

* Apply suggestions from code review
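In short: since 4.36 the GPT-NeoX rotary embedding had been a copy of Llama's, which downcasts its cached `cos`/`sin` tables to the model dtype and casts them again to `x.dtype` in `forward`. GPT-NeoX previously kept those tables in float32, so fp16 checkpoints could produce slightly different outputs than before 4.36, which is the backward-compatibility break the title refers to. The diff below drops both casts so the caches stay in the float32 they are computed in, and narrows the "Copied from" markers to `__init__`, since the classes no longer match Llama line for line. A minimal sketch of the observable effect (not part of the commit; the class name and `forward` signature are taken from the diff, the shapes are illustrative):

```python
# Minimal sketch of the observable change; class and forward signature come from
# the diff below, everything else is assumed.
import torch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding

rope = GPTNeoXRotaryEmbedding(dim=64, max_position_embeddings=2048, base=10000)
x = torch.randn(1, 8, 16, 64, dtype=torch.float16)  # [bs, num_attention_heads, seq_len, head_size]
cos, sin = rope(x, seq_len=16)

# On 4.36 this printed torch.float16 because of the .to(dtype=x.dtype) casts;
# with this commit the cached float32 tables are returned as-is.
print(cos.dtype, sin.dtype)
```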
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -526,8 +526,8 @@ def attention_mask_func(attention_scores, ltor_mask):
     return attention_scores
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with LlamaRotary->GPTNeoXRotary
 class GPTNeoXRotaryEmbedding(nn.Module):
+    # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -549,8 +549,8 @@ class GPTNeoXRotaryEmbedding(nn.Module):
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
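The two `register_buffer` lines are the heart of the regression: `dtype` here is whatever `_set_cos_sin_cache` receives (the default/model dtype), so an fp16 model stored its tables in fp16 on 4.36. With the cast dropped they keep the float32 precision of `inv_freq`. A hedged way to verify this on a real checkpoint; the `gpt_neox.layers[...].attention.rotary_emb` path is the usual GPT-NeoX module layout and is my assumption, not something this diff states:

```python
# Hedged check on a real checkpoint; module path is the usual GPT-NeoX layout.
import torch
from transformers import GPTNeoXForCausalLM

model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m", torch_dtype=torch.float16)
rope = model.gpt_neox.layers[0].attention.rotary_emb

print(model.gpt_neox.layers[0].attention.query_key_value.weight.dtype)  # torch.float16
print(rope.cos_cached.dtype, rope.sin_cached.dtype)  # expected after this fix: torch.float32
```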
@@ -558,15 +558,15 @@ class GPTNeoXRotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
         return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
+            self.cos_cached[:seq_len],
+            self.sin_cached[:seq_len],
         )
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->GPTNeoX
 class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
     """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)
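With the cast in `forward` gone as well, callers now receive float32 `cos`/`sin` slices even when the hidden states are fp16; PyTorch's type promotion then carries the rotation out in float32, which, as I read it, is what restores the pre-4.36 numerics. A self-contained illustration of that promotion rule (plain torch, nothing from the commit):

```python
# Plain-torch illustration of the promotion rule the change relies on:
# multiplying fp16 activations by a float32 table yields float32.
import torch

q = torch.randn(2, 8, 16, 64, dtype=torch.float16)  # fp16 query states
cos = torch.rand(16, 64, dtype=torch.float32)        # float32 cache slice, as returned after this fix

rotated = q * cos  # float16 * float32 promotes to float32
print(rotated.dtype)  # torch.float32
```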
@@ -579,14 +579,14 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX
 class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
     """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)
@@ -606,8 +606,8 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)
 
 
 def rotate_half(x):
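The same two-line change lands in all three rotary variants (base, linear scaling, dynamic NTK) because they share `_set_cos_sin_cache` for building the caches; only the scope of the "Copied from" markers differs. A quick hedged check that the scaled variants behave the same way (constructor signatures are taken from the diff; everything else is assumed):

```python
# Hedged sketch: all three variants should now keep float32 caches.
from transformers.models.gpt_neox.modeling_gpt_neox import (
    GPTNeoXDynamicNTKScalingRotaryEmbedding,
    GPTNeoXLinearScalingRotaryEmbedding,
    GPTNeoXRotaryEmbedding,
)

for cls in (GPTNeoXRotaryEmbedding, GPTNeoXLinearScalingRotaryEmbedding, GPTNeoXDynamicNTKScalingRotaryEmbedding):
    kwargs = {} if cls is GPTNeoXRotaryEmbedding else {"scaling_factor": 2.0}
    rope = cls(dim=64, max_position_embeddings=2048, base=10000, **kwargs)
    # The caches are built in __init__ via _set_cos_sin_cache; after this commit
    # they are registered without the .to(dtype) downcast, so float32 is expected.
    print(cls.__name__, rope.cos_cached.dtype, rope.sin_cached.dtype)
```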
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -235,6 +235,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
 # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
 class RotaryEmbedding(nn.Module):
+    # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -256,8 +257,8 @@ class RotaryEmbedding(nn.Module):
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -265,8 +266,8 @@ class RotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
         return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
+            self.cos_cached[:seq_len],
+            self.sin_cached[:seq_len],
         )
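GPT-NeoX Japanese keeps its own `RotaryEmbedding`, marked "Copied from" the GPT-NeoX one, so the cast removal has to be mirrored there, which is what the "fix copies" step in the commit list refers to (`make fix-copies` keeps such copies in sync). A hedged spot check that the two implementations still agree after the change:

```python
# Hedged sketch: the "Copied from" link means both implementations should return
# identical float32 tables after the fix is propagated.
import torch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
from transformers.models.gpt_neox_japanese.modeling_gpt_neox_japanese import RotaryEmbedding

x = torch.randn(1, 8, 16, 64, dtype=torch.float16)  # [bs, num_attention_heads, seq_len, head_size]
cos_a, sin_a = GPTNeoXRotaryEmbedding(dim=64)(x, seq_len=16)
cos_b, sin_b = RotaryEmbedding(dim=64)(x, seq_len=16)

print(torch.equal(cos_a, cos_b), torch.equal(sin_a, sin_b), cos_a.dtype)  # expected: True True torch.float32
```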
--- a/tests/models/gpt_neox/test_modeling_gpt_neox.py
+++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py
@@ -355,3 +355,13 @@ class GPTNeoXLanguageGenerationTest(unittest.TestCase):
         output_str = tokenizer.batch_decode(output_ids)[0]
         self.assertEqual(output_str, expected_output)
+
+    def pythia_integration_test(self):
+        model_name_or_path = "EleutherAI/pythia-70m"
+        model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device)
+        EXPECTED_LOGITS = torch.tensor([1069.0000, 228.7500, 1072.0000, 1072.0000, 1069.0000, 1068.0000, 1068.0000, 1071.0000, 1071.0000, 1071.0000, 1073.0000, 1070.0000, 1071.0000, 1075.0000, 1073.0000, 1075.0000, 1074.0000, 1069.0000, 1072.0000, 1071.0000, 1071.0000, 1071.0000, 1070.0000, 1069.0000, 1069.0000, 1069.0000, 1070.0000, 1075.0000, 1073.0000, 1074.0000]) # fmt: skip
+        input_ids = [29, 93, 303, 64, 5478, 49651, 10394, 187, 34, 12939, 875]
+        # alternative: tokenizer('<|im_start|>system\nA chat between')
+        input_ids = torch.as_tensor(input_ids)[None].to(torch_device)
+        outputs = model(input_ids)["logits"][:, -1][0, :30]
+        self.assertTrue(torch.allclose(EXPECTED_LOGITS, outputs, atol=1e-5))
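The new check pins the first 30 last-token logits of `EleutherAI/pythia-70m` in fp16 against hard-coded reference values, which is what actually guards the restored numerics. Note that the method is named `pythia_integration_test` rather than `test_pythia_integration`, so pytest's default collection will not pick it up; it has to be renamed or invoked explicitly, for example through unittest as sketched below (assuming the repository root is on the import path and `torch_device` is configured):

```python
# Hedged sketch: run the new check directly via unittest, since the method name
# lacks the test_ prefix that pytest collects by default.
import unittest

from tests.models.gpt_neox.test_modeling_gpt_neox import GPTNeoXLanguageGenerationTest

suite = unittest.TestSuite([GPTNeoXLanguageGenerationTest("pythia_integration_test")])
unittest.TextTestRunner(verbosity=2).run(suite)
```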