Unverified commit 34b43211 authored by mig-mfreitas, committed by GitHub

Add YaRN and Dynamic-YaRN RoPE Scaling Methods (#30910)

* Add YaRN and Dynamic-YaRN RoPE Scaling Methods

YaRN (Yet another RoPE extension method) combines the NTK-By-Parts
Interpolation and Attention Scaling methods, improving upon existing
RoPE interpolation methods for longer context window sizes.

Fine-tuned models maintain their original performance across benchmarks
while enabling efficient extrapolation and transfer learning for
quicker convergence, especially in compute-limited environments.
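As a rough, self-contained sketch of those two ingredients (illustrative numbers only, not code from this diff: head dim 128, RoPE base 10000, scale factor s = 4, original context 2048, beta_fast = 32, beta_slow = 1):

import math

import torch

# NTK-by-Parts: interpolate low-frequency RoPE dimensions, leave high-frequency
# dimensions untouched, and blend the two regimes with a linear ramp.
dim, base, s, original_max = 128, 10000.0, 4.0, 2048

pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
inv_freq_extrapolation = 1.0 / pos_freqs        # plain RoPE frequencies
inv_freq_interpolation = 1.0 / (s * pos_freqs)  # position-interpolated frequencies

def correction_dim(num_rotations):
    # Dimension whose wavelength completes `num_rotations` turns over the original context.
    return dim * math.log(original_max / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

low = max(math.floor(correction_dim(32)), 0)       # beta_fast boundary (extrapolation side)
high = min(math.ceil(correction_dim(1)), dim - 1)  # beta_slow boundary (interpolation side)
ramp = torch.clamp((torch.arange(dim // 2).float() - low) / (high - low), 0, 1)
inv_freq = inv_freq_interpolation * ramp + inv_freq_extrapolation * (1 - ramp)

# Attention Scaling: scale cos/sin by a temperature-like factor (YaRN paper, Eq. 22).
attention_factor = 0.1 * math.log(s) + 1.0  # ~1.139 for s = 4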

We implement YaRN and Dynamic-YaRN for the following list of models:

 - LLaMA
 - Falcon
 - GPT-NeoX
 - Olmo
 - Persimmon
 - Phi
 - StableLM
 - OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both
short and long sequence inputs.

For more details, please refer to https://arxiv.org/abs/2309.00071

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>

* Refactor YaRN implementation for LLaMA

Iterate on YaRN implementation for LLaMA and remove diff from remaining
models for increased PR modularity.

This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned')
  from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies
Co-authored-by: Miguel Monte e Freitas <miguelmontefreitas@tecnico.ulisboa.pt>

* Refactor Tensor Building Logic for YaRN

- Comply with the tensor building logic introduced in #30743
- Add a reference to the optimized attention factor equation
- Remove Dynamic-YaRN for a more agile deployment
Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>

* remove unwanted file

---------
Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>
Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
parent 7405c1c7
@@ -283,7 +283,6 @@ class FalconAttention(nn.Module):
         self.attention_dropout = nn.Dropout(config.attention_dropout)
         self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon
     def _init_rope(self):
         if self.config.rope_scaling is None:
             self.rotary_emb = FalconRotaryEmbedding(
@@ -188,7 +188,6 @@ class FuyuConfig(PretrainedConfig):
             **kwargs,
         )
 
-    # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
@@ -154,7 +154,6 @@ class GPTNeoXConfig(PretrainedConfig):
                 "The hidden size is not divisble by the number of attention heads! Make sure to update them!"
             )
 
-    # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
@@ -84,13 +84,22 @@ class LlamaConfig(PretrainedConfig):
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
         rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
-            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
+            strategies: linear, dynamic and yarn. Their scaling factor must be a float greater than 1. The expected format is
             `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
             `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
             these scaling strategies behave:
             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
             experimental feature, subject to breaking API changes in future versions.
+            For the `yarn` strategy, the dictionary may also contain the following fields:
+                `original_max_position_embeddings` (`int`, *optional*):
+                    The original maximum sequence length. This is used to scale the RoPE embeddings.
+                `attention_factor` (`float`, *optional*):
+                    The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the scaling factor (the value of `factor`).
+                `beta_fast` (`float`, *optional*):
+                    Parameter to set the boundary for extrapolation (only) in the linear ramp function.
+                `beta_slow` (`float`, *optional*):
+                    Parameter to set the boundary for interpolation (only) in the linear ramp function.
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
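For concreteness, a minimal configuration sketch of the new `yarn` entry (it assumes a transformers build that already contains this change, and the numbers are illustrative):

from transformers import LlamaConfig

config = LlamaConfig(
    max_position_embeddings=4096,  # keep the original value; don't bump it to the new maximum
    rope_theta=10000.0,
    rope_scaling={
        "type": "yarn",
        "factor": 4.0,                             # required: float > 1
        "original_max_position_embeddings": 4096,  # optional
        "beta_fast": 32.0,                         # optional, extrapolation boundary
        "beta_slow": 1.0,                          # optional, interpolation boundary
        # "attention_factor" is optional and defaults to 0.1 * ln(factor) + 1
    },
)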
@@ -178,15 +187,52 @@ class LlamaConfig(PretrainedConfig):
         if self.rope_scaling is None:
             return
 
-        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) < 2:
             raise ValueError(
-                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
+                "`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, "
+                f"got {self.rope_scaling}"
             )
         rope_scaling_type = self.rope_scaling.get("type", None)
         rope_scaling_factor = self.rope_scaling.get("factor", None)
-        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]:
             raise ValueError(
-                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}"
             )
         if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
             raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+        if rope_scaling_type != "yarn":
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6:
+            raise ValueError(
+                "`rope_scaling` with type "
+                f"{rope_scaling_type}"
+                " must be a dictionary with a maximum of six fields, `type`, `factor`, "
+                "`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, "
+                f"got {self.rope_scaling}"
+            )
+        original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
+        attention_factor = self.rope_scaling.get("attention_factor", None)
+        beta_fast = self.rope_scaling.get("beta_fast", None)
+        beta_slow = self.rope_scaling.get("beta_slow", None)
+
+        if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
+            raise ValueError(
+                f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
+            )
+        if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
+            raise ValueError(
+                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+            )
+        if beta_fast is not None and not isinstance(beta_fast, float):
+            raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
+        if beta_slow is not None and not isinstance(beta_slow, float):
+            raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
+
+        b_fast = beta_fast if beta_fast is not None else 32
+        b_slow = beta_slow if beta_slow is not None else 1
+        if b_fast < b_slow:
+            raise ValueError(
+                f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
+            )
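And a quick sketch of inputs the extended validation rejects (again assuming a build with this change):

from transformers import LlamaConfig

# Rejected: `factor` must be a float strictly greater than 1 (an int fails the isinstance check).
try:
    LlamaConfig(rope_scaling={"type": "yarn", "factor": 2})
except ValueError as err:
    print(err)

# Rejected: `beta_fast` must not be smaller than `beta_slow`.
try:
    LlamaConfig(rope_scaling={"type": "yarn", "factor": 2.0, "beta_fast": 1.0, "beta_slow": 8.0})
except ValueError as err:
    print(err)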
@@ -132,6 +132,77 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
         return cos, sin
 
 
+class LlamaYarnScalingRotaryEmbedding(LlamaRotaryEmbedding):
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        scaling_factor=1,
+        original_max_position_embeddings=2048,
+        attention_factor=None,
+        beta_fast=32,
+        beta_slow=1,
+        device=None,
+    ):
+        super().__init__(dim, max_position_embeddings, base, device, scaling_factor)
+
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.attention_factor = attention_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+
+        if self.attention_factor is None:
+            # Recommended attention factor for LLaMA models.
+            # For more details please refer to https://arxiv.org/pdf/2309.00071, Eq. 22.
+            self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0
+
+        self.compute_yarn_scaling(device)
+
+    # Inverse dimension formula to find the dimension based on the number of rotations
+    def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048):
+        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+    # Find dimension range bounds based on rotations
+    def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
+        low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings))
+        high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings))
+        return max(low, 0), min(high, dim - 1)
+
+    def linear_ramp_mask(self, min, max, dim):
+        if min == max:
+            max += 0.001  # Prevent singularity
+
+        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+        ramp_func = torch.clamp(linear_func, 0, 1)
+        return ramp_func
+
+    def forward(self, x, position_ids=None):
+        # Difference to the original RoPE: applies a scaling factor computed with
+        # the YaRN method (NTK-by-Parts + Attn Scaling)
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        cos, sin = super().forward(x, position_ids)
+        cos = cos * self.mscale
+        sin = sin * self.mscale
+        return cos, sin
+
+    def compute_yarn_scaling(self, device):
+        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs)
+
+        low, high = self.find_correction_range(
+            self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings
+        )
+        # Get n-dimensional rotational scaling corrected for extrapolation
+        inv_freq_mask = 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device)
+        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+
+        self.register_buffer("inv_freq", inv_freq)
+        # Get n-dimensional magnitude scaling corrected for interpolation
+        self.mscale = self.attention_factor
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
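As a usage sketch for the new rotary embedding class on its own (assuming a transformers build containing this change; shapes and values are illustrative, and the return shapes follow the base `LlamaRotaryEmbedding` of this transformers version):

import torch
from transformers.models.llama.modeling_llama import LlamaYarnScalingRotaryEmbedding

head_dim, num_heads, seq_len = 128, 32, 8192
rope = LlamaYarnScalingRotaryEmbedding(
    head_dim,
    max_position_embeddings=2048,
    base=10000,
    scaling_factor=4.0,  # 2048 -> 8192 positions
    original_max_position_embeddings=2048,
)
x = torch.randn(1, num_heads, seq_len, head_dim)   # [bs, num_heads, seq_len, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0)  # [bs, seq_len]
cos, sin = rope(x, position_ids)                   # already rescaled by the attention factor
print(rope.attention_factor)                       # 0.1 * ln(4) + 1 ~= 1.139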
@@ -258,6 +329,15 @@ class LlamaAttention(nn.Module):
         else:
             scaling_type = self.config.rope_scaling["type"]
             scaling_factor = self.config.rope_scaling["factor"]
+            # Yarn parameters
+            kwargs = {
+                "original_max_position_embeddings": self.config.rope_scaling.get("original_max_position_embeddings", None),
+                "attention_factor": self.config.rope_scaling.get("attention_factor", None),
+                "beta_fast": self.config.rope_scaling.get("beta_fast", None),
+                "beta_slow": self.config.rope_scaling.get("beta_slow", None),
+            }
+            kwargs = {k: v for k, v in kwargs.items() if v is not None}
             if scaling_type == "linear":
                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                     self.head_dim,
@@ -272,6 +352,14 @@ class LlamaAttention(nn.Module):
                     scaling_factor=scaling_factor,
                     base=self.rope_theta,
                 )
+            elif scaling_type == "yarn":
+                self.rotary_emb = LlamaYarnScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                    **kwargs,
+                )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
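To make the keyword filtering above concrete, a small standalone sketch (plain Python, no transformers needed) of what reaches the YaRN embedding when only some optional fields are set:

rope_scaling = {"type": "yarn", "factor": 4.0, "beta_fast": 64.0}

kwargs = {
    "original_max_position_embeddings": rope_scaling.get("original_max_position_embeddings", None),
    "attention_factor": rope_scaling.get("attention_factor", None),
    "beta_fast": rope_scaling.get("beta_fast", None),
    "beta_slow": rope_scaling.get("beta_slow", None),
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
print(kwargs)  # {'beta_fast': 64.0}; everything else falls back to the class defaults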
@@ -160,7 +160,6 @@ class OlmoConfig(PretrainedConfig):
             **kwargs,
         )
 
-    # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
@@ -236,7 +236,6 @@ class OlmoAttention(nn.Module):
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
         self._init_rope()
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Olmo
     def _init_rope(self):
         if self.config.rope_scaling is None:
             self.rotary_emb = OlmoRotaryEmbedding(
@@ -138,7 +138,6 @@ class PersimmonConfig(PretrainedConfig):
             **kwargs,
         )
 
-    # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
@@ -165,7 +165,6 @@ class PhiConfig(PretrainedConfig):
             **kwargs,
         )
 
-    # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
@@ -164,7 +164,6 @@ class StableLmConfig(PretrainedConfig):
             **kwargs,
        )
 
-    # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
@@ -55,6 +55,7 @@ if is_torch_available():
         LlamaDynamicNTKScalingRotaryEmbedding,
         LlamaLinearScalingRotaryEmbedding,
         LlamaRotaryEmbedding,
+        LlamaYarnScalingRotaryEmbedding,
     )
@@ -397,7 +398,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_save_load_fast_init_from_base(self):
         pass
 
-    @parameterized.expand([("linear",), ("dynamic",)])
+    @parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
@@ -491,6 +492,26 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         torch.testing.assert_close(ntk_sin_long, original_sin_long)
         self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
 
+        # Sanity check Yarn RoPE scaling
+        yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
+        yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
+        torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
+        torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(yarn_cos_short, original_cos_short)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(yarn_sin_short, original_sin_short)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(yarn_cos_long, original_cos_long)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(yarn_sin_long, original_sin_long)
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes