Unverified commit 441de62f authored by Joao Gante, committed by GitHub

RoPE models: add numerical sanity-check test for RoPE scaling (#29808)

* add hard rope scaling test

* make fixup

* quick rope scaling tests

* add copy statements
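
The property these tests pin down numerically, shown here as a standalone sketch with a plain-torch RoPE (illustrative only; `rope_cos_sin`, `dim=64` and the other names below are not part of this diff): with linear scaling, the cos/sin values at scaled position p should reproduce the original values at position p / scaling_factor.

import torch

def rope_cos_sin(positions, dim=64, base=10000.0, scaling_factor=1.0):
    # inv_freq_i = base**(-2i/dim); linear scaling divides the positions by the factor
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = positions.to(torch.float32) / scaling_factor
    freqs = torch.outer(t, inv_freq)          # (seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim)
    return emb.cos(), emb.sin()

positions = torch.arange(100)
cos_orig, sin_orig = rope_cos_sin(positions)
cos_lin, sin_lin = rope_cos_sin(positions, scaling_factor=10.0)

# With a factor of 10, scaled position 50 reproduces original position 5.
torch.testing.assert_close(cos_lin[50], cos_orig[5])
torch.testing.assert_close(sin_lin[50], sin_orig[5])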
parent aac7099c
@@ -45,6 +45,11 @@ if is_torch_available():
FalconForTokenClassification,
FalconModel,
)
from transformers.models.falcon.modeling_falcon import (
FalconDynamicNTKScalingRotaryEmbedding,
FalconLinearScalingRotaryEmbedding,
FalconRotaryEmbedding,
)
class FalconModelTester:
@@ -408,7 +413,8 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
)
@parameterized.expand([("linear",), ("dynamic",)])
def test_model_rope_scaling(self, scaling_type):
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -438,6 +444,65 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device
# Sanity check original RoPE
original_rope = FalconRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length)
original_cos_long, original_sin_long = original_rope(x, long_input_length)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = FalconLinearScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
# with scaling_factor (i.e. that `inv_freq` decreases)
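# (For reference, dynamic NTK variants typically enlarge the base once seq_len exceeds max_position_embeddings,
#  roughly base * ((scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)) ** (dim / (dim - 2)),
#  which shrinks inv_freq = base**(-2i/dim); the exact rescaling is implementation-specific.)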
ntk_scaling_rope = FalconDynamicNTKScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
......
@@ -38,6 +38,11 @@ if is_torch_available():
GPTNeoXForTokenClassification,
GPTNeoXModel,
)
from transformers.models.gpt_neox.modeling_gpt_neox import (
GPTNeoXDynamicNTKScalingRotaryEmbedding,
GPTNeoXLinearScalingRotaryEmbedding,
GPTNeoXRotaryEmbedding,
)
class GPTNeoXModelTester:
@@ -301,7 +306,8 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
pass
@parameterized.expand([("linear",), ("dynamic",)])
def test_model_rope_scaling(self, scaling_type):
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->GPTNeoX
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -331,6 +337,66 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
# Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->GPTNeoX, rope_theta->rotary_emb_base
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device
# Sanity check original RoPE
original_rope = GPTNeoXRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rotary_emb_base,
).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length)
original_cos_long, original_sin_long = original_rope(x, long_input_length)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = GPTNeoXLinearScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rotary_emb_base,
scaling_factor=scaling_factor,
).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
# with scaling_factor (i.e. that `inv_freq` decreases)
ntk_scaling_rope = GPTNeoXDynamicNTKScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rotary_emb_base,
scaling_factor=scaling_factor,
).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
@require_torch
class GPTNeoXLanguageGenerationTest(unittest.TestCase):
......
@@ -51,6 +51,11 @@ if is_torch_available():
LlamaModel,
LlamaTokenizer,
)
from transformers.models.llama.modeling_llama import (
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding,
)
class LlamaModelTester:
@@ -370,7 +375,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
pass
@parameterized.expand([("linear",), ("dynamic",)])
def test_model_rope_scaling(self, scaling_type):
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -400,6 +405,69 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE
original_rope = LlamaRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = LlamaLinearScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
# with scaling_factor (i.e. that `inv_freq` decreases)
ntk_scaling_rope = LlamaDynamicNTKScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
......
@@ -45,6 +45,11 @@ if is_torch_available():
PersimmonForSequenceClassification,
PersimmonModel,
)
from transformers.models.persimmon.modeling_persimmon import (
PersimmonDynamicNTKScalingRotaryEmbedding,
PersimmonLinearScalingRotaryEmbedding,
PersimmonRotaryEmbedding,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
@@ -365,8 +370,8 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
pass
@parameterized.expand([("linear",), ("dynamic",)])
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling with Llama->Persimmon
def test_model_rope_scaling(self, scaling_type):
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Persimmon
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -396,6 +401,66 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
# Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Persimmon
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device
# Sanity check original RoPE
original_rope = PersimmonRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length)
original_cos_long, original_sin_long = original_rope(x, long_input_length)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = PersimmonLinearScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
# with scaling_factor (i.e. that `inv_freq` decreases)
ntk_scaling_rope = PersimmonDynamicNTKScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
@require_torch
class PersimmonIntegrationTest(unittest.TestCase):
......
@@ -19,8 +19,9 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import PhiConfig, is_torch_available
from transformers import PhiConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
@@ -46,6 +47,11 @@ if is_torch_available():
PhiForTokenClassification,
PhiModel,
)
from transformers.models.phi.modeling_phi import (
PhiDynamicNTKScalingRotaryEmbedding,
PhiLinearScalingRotaryEmbedding,
PhiRotaryEmbedding,
)
class PhiModelTester:
@@ -360,6 +366,98 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@parameterized.expand([("linear",), ("dynamic",)])
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Phi
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
set_seed(42) # Fixed seed at init time so the two models get the same random weights
original_model = PhiModel(config)
original_model.to(torch_device)
original_model.eval()
original_short_output = original_model(short_input).last_hidden_state
original_long_output = original_model(long_input).last_hidden_state
set_seed(42) # Fixed seed at init time so the two models get the same random weights
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
scaled_model = PhiModel(config)
scaled_model.to(torch_device)
scaled_model.eval()
scaled_short_output = scaled_model(short_input).last_hidden_state
scaled_long_output = scaled_model(long_input).last_hidden_state
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
# maximum sequence length, so the outputs for the short input should match.
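# (Here the 10-token short input stays within max_position_embeddings, so the dynamic variant keeps the original
# `inv_freq` and the short outputs coincide; only the 1.5x-long input triggers the rescaling.)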
if scaling_type == "dynamic":
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
else:
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
# Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Phi
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device
# Sanity check original RoPE
original_rope = PhiRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length)
original_cos_long, original_sin_long = original_rope(x, long_input_length)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = PhiLinearScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
# with scaling_factor (i.e. that `inv_freq` decreases)
ntk_scaling_rope = PhiDynamicNTKScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
......
@@ -44,6 +44,11 @@ if is_torch_available():
StableLmForSequenceClassification,
StableLmModel,
)
from transformers.models.stablelm.modeling_stablelm import (
StableLmDynamicNTKScalingRotaryEmbedding,
StableLmLinearScalingRotaryEmbedding,
StableLmRotaryEmbedding,
)
# Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm
@@ -351,7 +356,8 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@parameterized.expand([("linear",), ("dynamic",)])
def test_model_rope_scaling(self, scaling_type):
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -381,6 +387,66 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
# Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->StableLm
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device
# Sanity check original RoPE
original_rope = StableLmRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length)
original_cos_long, original_sin_long = original_rope(x, long_input_length)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = StableLmLinearScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies decrease
# with scaling_factor (i.e. that `inv_freq` decreases)
ntk_scaling_rope = StableLmDynamicNTKScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
@require_torch
class StableLmModelIntegrationTest(unittest.TestCase):
......