Unverified Commit 2e113422 authored by Joao Gante, committed by GitHub
parent 5a4a76ed
...@@ -1295,6 +1295,7 @@ else: ...@@ -1295,6 +1295,7 @@ else:
) )
_import_structure["modeling_flash_attention_utils"] = [] _import_structure["modeling_flash_attention_utils"] = []
_import_structure["modeling_outputs"] = [] _import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
_import_structure["modeling_utils"] = ["PreTrainedModel"] _import_structure["modeling_utils"] = ["PreTrainedModel"]
# PyTorch models structure # PyTorch models structure
...@@ -6010,6 +6011,7 @@ if TYPE_CHECKING: ...@@ -6010,6 +6011,7 @@ if TYPE_CHECKING:
WatermarkLogitsProcessor, WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor, WhisperTimeStampLogitsProcessor,
) )
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .models.albert import ( from .models.albert import (
AlbertForMaskedLM, AlbertForMaskedLM,
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
from .configuration_utils import PretrainedConfig
from .utils import is_torch_available, logging
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
def _compute_default_rope_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
base = rope_kwargs["base"]
dim = rope_kwargs["dim"]
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
attention_factor = 1.0 # Unused in this type of RoPE
# Compute the inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor
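# Editor's illustrative sketch (not part of this commit): the legacy keyword path
# can be exercised directly, assuming a hypothetical 64-dim rotary head:
#
#     inv_freq, attention_factor = _compute_default_rope_parameters(base=10000.0, dim=64)
#     # inv_freq has shape (dim // 2,) == (32,); attention_factor is always 1.0 here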
def _compute_linear_scaling_rope_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
factor = rope_kwargs["factor"]
elif config is not None:
factor = config.rope_scaling["factor"]
# Gets the default RoPE parameters
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
# Then applies linear scaling to the frequencies.
# NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
# applying scaling to the inverse frequencies is equivalent.
inv_freq /= factor
return inv_freq, attention_factor
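# Editor's sketch (assumed values, not part of this commit): dividing the inverse
# frequencies by `factor` is equivalent to dividing the position ids, since the
# rotary angles are `position_ids * inv_freq`:
#
#     scaled, _ = _compute_linear_scaling_rope_parameters(base=10000.0, dim=64, factor=4.0)
#     default, _ = _compute_default_rope_parameters(base=10000.0, dim=64)
#     torch.testing.assert_close(scaled, default / 4.0)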
def _compute_dynamic_ntk_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length, used to update the dynamic RoPE at inference time.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
base = rope_kwargs["base"]
dim = rope_kwargs["dim"]
max_position_embeddings = rope_kwargs["max_position_embeddings"]
factor = rope_kwargs["factor"]
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"]
attention_factor = 1.0 # Unused in this type of RoPE
# seq_len: default to max_position_embeddings, e.g. at init time
seq_len = seq_len if seq_len is not None else max_position_embeddings
# Compute the inverse frequencies
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor
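# Editor's sketch (hypothetical values, not part of this commit): the effective
# base -- and therefore the RoPE wavelengths -- only grows once `seq_len` exceeds
# the pretrained `max_position_embeddings`:
#
#     kwargs = {"base": 10000.0, "dim": 64, "max_position_embeddings": 2048, "factor": 2.0}
#     at_limit, _ = _compute_dynamic_ntk_parameters(seq_len=2048, **kwargs)  # base unchanged
#     beyond, _ = _compute_dynamic_ntk_parameters(seq_len=8192, **kwargs)    # base rescaled upwards
#     assert torch.all(beyond <= at_limit)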
def _compute_yarn_parameters(
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with YaRN scaling. Please refer to the
[original paper](https://arxiv.org/abs/2309.00071)
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# No need to keep BC with yarn, unreleased when this new pattern was created.
if len(rope_kwargs) > 0:
raise ValueError(
f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
)
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"]
# Sets the attention factor as suggested in the paper
attention_factor = config.rope_scaling.get("attention_factor")
if attention_factor is None:
attention_factor = 0.1 * math.log(factor) + 1.0
# Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
beta_fast = config.rope_scaling.get("beta_fast") or 32
beta_slow = config.rope_scaling.get("beta_slow") or 1
# Compute the inverse frequencies
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
"""Inverse dimension formula to find the dimension based on the number of rotations"""
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
"""Find dimension range bounds based on rotations"""
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)
def linear_ramp_mask(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq, attention_factor
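# Editor's sketch (assumes the `LlamaConfig` updated in this PR): with only the
# required fields set, the attention factor falls back to the paper's default,
# 0.1 * ln(factor) + 1.0:
#
#     from transformers import LlamaConfig
#     config = LlamaConfig(rope_scaling={"rope_type": "yarn", "factor": 4.0})
#     inv_freq, attention_factor = _compute_yarn_parameters(config, device=None)
#     # attention_factor == 0.1 * math.log(4.0) + 1.0 (~1.139); high-frequency dims keep the
#     # extrapolated frequencies, low-frequency dims are interpolated by `factor`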
def _compute_longrope_parameters(
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with LongRoPE scaling. Please refer to the
[original implementation](https://github.com/microsoft/LongRoPE)
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# No need to keep BC with longrope, unreleased when this new pattern was created.
if len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
f"{rope_kwargs}"
)
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
long_factor = config.rope_scaling["long_factor"]
short_factor = config.rope_scaling["short_factor"]
factor = config.rope_scaling.get("factor")
attention_factor = config.rope_scaling.get("attention_factor")
# NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if hasattr(config, "original_max_position_embeddings"):
max_position_embeddings = config.original_max_position_embeddings
expanded_max_position_embeddings = config.max_position_embeddings
factor = expanded_max_position_embeddings / max_position_embeddings
else:
max_position_embeddings = config.max_position_embeddings
expanded_max_position_embeddings = max_position_embeddings * factor
# Sets the attention factor as suggested in the paper
if attention_factor is None:
if factor <= 1.0:
attention_factor = 1.0
else:
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
# Compute the inverse frequencies -- scaled based on the target sequence length
if expanded_max_position_embeddings > max_position_embeddings:
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
else:
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
return inv_freq, attention_factor
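# Editor's sketch (hypothetical config, not part of this commit): without an
# `original_max_position_embeddings` field, `factor` decides both which factor
# list is selected and the default attention scaling:
#
#     from transformers import LlamaConfig
#     config = LlamaConfig(
#         rope_scaling={
#             "rope_type": "longrope",
#             "factor": 4.0,
#             "short_factor": [1.0] * 64,  # head_dim // 2 entries for the default config
#             "long_factor": [2.0] * 64,
#         }
#     )
#     inv_freq, attention_factor = _compute_longrope_parameters(config, device=None)
#     # the expanded length (2048 * 4) exceeds 2048, so `long_factor` is used and
#     # attention_factor == sqrt(1 + ln(4) / ln(2048)) (~1.087)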
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
# parameterizations, as long as the callable has the same signature.
ROPE_INIT_FUNCTIONS = {
"default": _compute_default_rope_parameters,
"linear": _compute_linear_scaling_rope_parameters,
"dynamic": _compute_dynamic_ntk_parameters,
"yarn": _compute_yarn_parameters,
"longrope": _compute_longrope_parameters,
}
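# Editor's sketch (hypothetical rope type, not part of this commit): a custom
# parameterization only needs a callable with the same signature, e.g.
#
#     def _compute_halved_rope_parameters(config=None, device=None, seq_len=None, **rope_kwargs):
#         inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
#         return inv_freq * 0.5, attention_factor
#
#     ROPE_INIT_FUNCTIONS["halved"] = _compute_halved_rope_parameters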
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
missing_keys = required_keys - received_keys
if missing_keys:
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
if optional_keys is not None:
unused_keys = received_keys - required_keys - optional_keys
else:
unused_keys = received_keys - required_keys
if unused_keys:
raise KeyError(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
def _validate_default_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
required_keys = {"rope_type"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
required_keys = {"rope_type", "factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_yarn_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
required_keys = {"rope_type", "factor"}
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
raise ValueError(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
beta_fast = rope_scaling.get("beta_fast")
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
beta_slow = rope_scaling.get("beta_slow")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
if (beta_fast or 32) < (beta_slow or 1):
raise ValueError(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
)
def _validate_longrope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
required_keys = {"rope_type", "short_factor", "long_factor"}
optional_keys = {"attention_factor", "factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
short_factor = rope_scaling.get("short_factor")
if not isinstance(short_factor, list) or not all(isinstance(x, (int, float)) for x in short_factor):
raise ValueError(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
if not len(short_factor) == dim // 2:
raise ValueError(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
long_factor = rope_scaling.get("long_factor")
if not isinstance(long_factor, list) or not all(isinstance(x, (int, float)) for x in long_factor):
raise ValueError(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
if not len(long_factor) == dim // 2:
raise ValueError(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
# unique to longrope (= undesirable)
if hasattr(config, "original_max_position_embeddings"):
logger.warning_once(
"This model has set a `original_max_position_embeddings` field, to be used together with "
"`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
"with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
"as it is compatible with most model architectures."
)
else:
factor = rope_scaling.get("factor")
if factor is None:
raise ValueError("Missing required keys in `rope_scaling`: 'factor'")
elif not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
raise ValueError(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
ROPE_VALIDATION_FUNCTIONS = {
"default": _validate_default_rope_parameters,
"linear": _validate_linear_scaling_rope_parameters,
"dynamic": _validate_linear_scaling_rope_parameters, # `dynamic` has the same validation pattern as `linear`
"yarn": _validate_yarn_parameters,
"longrope": _validate_longrope_parameters,
}
def rope_config_validation(config: PretrainedConfig):
"""
Validate the RoPE config arguments, given a `PretrainedConfig` object
"""
rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
if rope_scaling is None:
return
possible_rope_types = set(ROPE_INIT_FUNCTIONS.keys())
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
if rope_type is None:
raise ValueError(
f"rope_scaling must contain a non-None 'rope_type' field. Possible options are {possible_rope_types}"
)
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
if validation_fn is not None:
validation_fn(config)
else:
raise ValueError(
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
)
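# Editor's sketch (assumes the `LlamaConfig` updated in this PR): validation is a
# no-op for well-formed configs and raises otherwise:
#
#     from transformers import LlamaConfig
#     config = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
#     rope_config_validation(config)                  # passes silently
#     config.rope_scaling = {"rope_type": "linear"}   # drop the required `factor`
#     rope_config_validation(config)                  # raises KeyError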
...@@ -80,7 +80,8 @@ class ChameleonRMSNorm(nn.Module): ...@@ -80,7 +80,8 @@ class ChameleonRMSNorm(nn.Module):
ALL_LAYERNORM_LAYERS.append(ChameleonRMSNorm) ALL_LAYERNORM_LAYERS.append(ChameleonRMSNorm)
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonRotaryEmbedding(nn.Module): class ChameleonRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__() super().__init__()
...@@ -110,7 +111,8 @@ class ChameleonRotaryEmbedding(nn.Module): ...@@ -110,7 +111,8 @@ class ChameleonRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Chameleon # copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding): class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
"""ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
...@@ -121,7 +123,8 @@ class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding): ...@@ -121,7 +123,8 @@ class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
return cos, sin return cos, sin
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Chameleon # copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding): class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
"""ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
...@@ -265,7 +268,8 @@ class ChameleonAttention(nn.Module): ...@@ -265,7 +268,8 @@ class ChameleonAttention(nn.Module):
self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim)) self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim))
self._init_rope() self._init_rope()
# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon # copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon
# TODO(joao): add me back asap :)
def _init_rope(self): def _init_rope(self):
if self.config.rope_scaling is None: if self.config.rope_scaling is None:
self.rotary_emb = ChameleonRotaryEmbedding( self.rotary_emb = ChameleonRotaryEmbedding(
...@@ -358,7 +362,8 @@ class ChameleonAttention(nn.Module): ...@@ -358,7 +362,8 @@ class ChameleonAttention(nn.Module):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon # copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonFlashAttention2(ChameleonAttention): class ChameleonFlashAttention2(ChameleonAttention):
""" """
Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays
...@@ -576,7 +581,8 @@ CHAMELEON_ATTENTION_CLASSES = { ...@@ -576,7 +581,8 @@ CHAMELEON_ATTENTION_CLASSES = {
} }
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
# TODO(joao): add me back asap :)
class ChameleonDecoderLayer(nn.Module): class ChameleonDecoderLayer(nn.Module):
def __init__(self, config: ChameleonConfig, layer_idx: int): def __init__(self, config: ChameleonConfig, layer_idx: int):
super().__init__() super().__init__()
......
...@@ -295,7 +295,8 @@ class CohereAttention(nn.Module): ...@@ -295,7 +295,8 @@ class CohereAttention(nn.Module):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere # copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
# TODO(joao): add me back asap :)
class CohereFlashAttention2(CohereAttention): class CohereFlashAttention2(CohereAttention):
""" """
Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays
...@@ -409,7 +410,8 @@ class CohereFlashAttention2(CohereAttention): ...@@ -409,7 +410,8 @@ class CohereFlashAttention2(CohereAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention Llama->Cohere # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention Llama->Cohere
# TODO(joao): add me back asap :)
class CohereSdpaAttention(CohereAttention): class CohereSdpaAttention(CohereAttention):
""" """
Cohere attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from Cohere attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
...@@ -697,7 +699,8 @@ COHERE_INPUTS_DOCSTRING = r""" ...@@ -697,7 +699,8 @@ COHERE_INPUTS_DOCSTRING = r"""
"The bare Cohere Model outputting raw hidden-states without any specific head on top.", "The bare Cohere Model outputting raw hidden-states without any specific head on top.",
COHERE_START_DOCSTRING, COHERE_START_DOCSTRING,
) )
# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere # copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere
# TODO(joao): add me back asap :)
class CohereModel(CoherePreTrainedModel): class CohereModel(CoherePreTrainedModel):
""" """
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CohereDecoderLayer`] Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CohereDecoderLayer`]
......
...@@ -1624,7 +1624,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel): ...@@ -1624,7 +1624,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel):
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -1363,7 +1363,7 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel): ...@@ -1363,7 +1363,7 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel):
@add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -20,10 +20,7 @@ ...@@ -20,10 +20,7 @@
"""LLaMA model configuration""" """LLaMA model configuration"""
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...modeling_rope_utils import rope_config_validation
logger = logging.get_logger(__name__)
class LlamaConfig(PretrainedConfig): class LlamaConfig(PretrainedConfig):
...@@ -84,22 +81,35 @@ class LlamaConfig(PretrainedConfig): ...@@ -84,22 +81,35 @@ class LlamaConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0): rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings. The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*): rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
strategies: linear, dynamic and yarn. Their scaling factor must be a float greater than 1. The expected format is `max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update to determine which scaling to apply.
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how Expected contents:
these scaling strategies behave: `rope_type` (`str`):
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
experimental feature, subject to breaking API changes in future versions. with 'default' being the original RoPE implementation.
For the `yarn` strategy, the dictionary may also contain the following fields: `factor` (`float`, *optional*):
`original_max_position_embeddings` (`int`, *optional*): Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
The original maximum sequence length. This is used to scale the RoPE embeddings. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
`max_position_embeddings`.
`attention_factor` (`float`, *optional*): `attention_factor` (`float`, *optional*):
The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the `original_max_position_embeddings/max_position_embeddings` ratio. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to the value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*): `beta_fast` (`float`, *optional*):
Parameter to set the boundary for extrapolation (only) in the linear ramp function. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*): `beta_slow` (`float`, *optional*):
Parameter to set the boundary for interpolation (only) in the linear ramp function. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
attention_bias (`bool`, *optional*, defaults to `False`): attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention. Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0): attention_dropout (`float`, *optional*, defaults to 0.0):
...@@ -167,11 +177,13 @@ class LlamaConfig(PretrainedConfig): ...@@ -167,11 +177,13 @@ class LlamaConfig(PretrainedConfig):
self.use_cache = use_cache self.use_cache = use_cache
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.rope_scaling = rope_scaling self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias self.attention_bias = attention_bias
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias self.mlp_bias = mlp_bias
# Validate the correctness of rotary position embeddings parameters
rope_config_validation(self)
super().__init__( super().__init__(
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
...@@ -179,60 +191,3 @@ class LlamaConfig(PretrainedConfig): ...@@ -179,60 +191,3 @@ class LlamaConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings, tie_word_embeddings=tie_word_embeddings,
**kwargs, **kwargs,
) )
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) < 2:
raise ValueError(
"`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
if rope_scaling_type != "yarn":
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6:
raise ValueError(
"`rope_scaling` with type "
f"{rope_scaling_type}"
" must be a dictionary with a maximum of six fields, `type`, `factor`,"
"`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, "
f"got {self.rope_scaling}"
)
original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
attention_factor = self.rope_scaling.get("attention_factor", None)
beta_fast = self.rope_scaling.get("beta_fast", None)
beta_slow = self.rope_scaling.get("beta_slow", None)
if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
raise ValueError(
f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
)
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
raise ValueError(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
b_fast = beta_fast if beta_fast is not None else 32
b_slow = beta_slow if beta_slow is not None else 1
if b_fast < b_slow:
raise ValueError(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
)
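# Editor's sketch (not part of the diff): with the new schema, scaling is declared
# entirely through `rope_scaling` and validated inside `__init__`:
#
#     from transformers import LlamaConfig
#     config = LlamaConfig(
#         rope_scaling={"rope_type": "yarn", "factor": 4.0, "beta_fast": 32.0, "beta_slow": 1.0},
#     )
#     # `rope_config_validation(config)` has already run; a malformed dict (e.g. a
#     # missing `factor`) would have raised at construction time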
...@@ -37,6 +37,7 @@ from ...modeling_outputs import ( ...@@ -37,6 +37,7 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import ( from ...utils import (
...@@ -75,24 +76,77 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) ...@@ -75,24 +76,77 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
class LlamaRotaryEmbedding(nn.Module): class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[LlamaConfig] = None,
):
super().__init__() super().__init__()
self.scaling_factor = scaling_factor # TODO (joao): remove the `if` below, only used for BC
self.dim = dim self.rope_kwargs = {}
self.max_position_embeddings = max_position_embeddings if config is None:
self.base = base logger.warning_once(
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
self.register_buffer("inv_freq", inv_freq, persistent=False) "`config` argument. All other arguments will be removed in v4.45"
# For BC we register cos and sin cached )
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad() @torch.no_grad()
def forward(self, x, position_ids): def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size] if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float() position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): with torch.autocast(device_type=device_type, enabled=False):
...@@ -100,107 +154,37 @@ class LlamaRotaryEmbedding(nn.Module): ...@@ -100,107 +154,37 @@ class LlamaRotaryEmbedding(nn.Module):
emb = torch.cat((freqs, freqs), dim=-1) emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() cos = emb.cos()
sin = emb.sin() sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
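# Editor's sketch (hypothetical smoke test, not part of the diff): the rotary
# module is now built once from the config and produces the cos/sin tables shared
# by all decoder layers:
#
#     config = LlamaConfig()
#     rope = LlamaRotaryEmbedding(config=config)
#     head_dim = config.hidden_size // config.num_attention_heads
#     x = torch.zeros(1, 1, 10, head_dim)        # only dtype/device are read from x
#     position_ids = torch.arange(10)[None, :]
#     cos, sin = rope(x, position_ids)           # each of shape (1, 10, head_dim)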
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def forward(self, x, position_ids): def __init__(self, *args, **kwargs):
# difference to the original RoPE: a scaling factor is aplied to the position ids logger.warning_once(
position_ids = position_ids.float() / self.scaling_factor "`LlamaLinearScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use "
cos, sin = super().forward(x, position_ids) "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
return cos, sin )
kwargs["rope_type"] = "linear"
super().__init__(*args, **kwargs)
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def forward(self, x, position_ids): def __init__(self, *args, **kwargs):
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length logger.warning_once(
seq_len = torch.max(position_ids) + 1 "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
if seq_len > self.max_position_embeddings: "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
base = self.base * ( "__init__)."
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (
base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
cos, sin = super().forward(x, position_ids)
return cos, sin
class LlamaYarnScalingRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1,
original_max_position_embeddings=2048,
attention_factor=None,
beta_fast=32,
beta_slow=1,
device=None,
):
super().__init__(dim, max_position_embeddings, base, device, scaling_factor)
self.original_max_position_embeddings = original_max_position_embeddings
self.attention_factor = attention_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
if self.attention_factor is None:
# Recommended attention factor for LLaMA models.
# For more details please refer to https://arxiv.org/pdf/2309.00071, Eq. 22.
self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0
self.compute_yarn_scaling(device)
# Inverse dimension formula to find the dimension based on the number of rotations
def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
# Find dimension range bounds based on rotations
def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)
def linear_ramp_mask(self, min, max, dim):
if min == max:
max += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def forward(self, x, position_ids=None):
# Difference to the original RoPE: applies a scaling factor computed with
# the YaRN method (NTK-by-Parts + Attn Scaling)
# x: [bs, num_attention_heads, seq_len, head_size]
cos, sin = super().forward(x, position_ids)
cos = cos * self.mscale
sin = sin * self.mscale
return cos, sin
def compute_yarn_scaling(self, device):
pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs)
low, high = self.find_correction_range(
self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings
) )
# Get n-dimensional rotational scaling corrected for extrapolation kwargs["rope_type"] = "dynamic"
inv_freq_mask = 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) super().__init__(*args, **kwargs)
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
self.register_buffer("inv_freq", inv_freq)
# Get n-dimensional magnitude scaling corrected for interpolation
self.mscale = self.attention_factor
def rotate_half(x): def rotate_half(x):
...@@ -317,51 +301,9 @@ class LlamaAttention(nn.Module): ...@@ -317,51 +301,9 @@ class LlamaAttention(nn.Module):
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self._init_rope()
# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
def _init_rope(self): self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
# Yarn parameters
kwargs = {
"dim": self.config.rope_scaling.get("original_max_position_embeddings", None),
"max_position_embeddings": self.config.rope_scaling.get("attention_factor", None),
"base": self.config.rope_scaling.get("beta_fast", None),
"scaling_factor": self.config.rope_scaling.get("beta_slow", None),
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "yarn":
self.rotary_emb = LlamaYarnScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward( def forward(
self, self,
...@@ -372,6 +314,7 @@ class LlamaAttention(nn.Module): ...@@ -372,6 +314,7 @@ class LlamaAttention(nn.Module):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
...@@ -402,7 +345,16 @@ class LlamaAttention(nn.Module): ...@@ -402,7 +345,16 @@ class LlamaAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids) cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
...@@ -471,6 +423,7 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -471,6 +423,7 @@ class LlamaFlashAttention2(LlamaAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache): if isinstance(past_key_value, StaticCache):
raise ValueError( raise ValueError(
...@@ -493,7 +446,16 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -493,7 +446,16 @@ class LlamaFlashAttention2(LlamaAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids) cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
...@@ -573,6 +535,7 @@ class LlamaSdpaAttention(LlamaAttention): ...@@ -573,6 +535,7 @@ class LlamaSdpaAttention(LlamaAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
...@@ -589,6 +552,7 @@ class LlamaSdpaAttention(LlamaAttention): ...@@ -589,6 +552,7 @@ class LlamaSdpaAttention(LlamaAttention):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
...@@ -601,7 +565,16 @@ class LlamaSdpaAttention(LlamaAttention): ...@@ -601,7 +565,16 @@ class LlamaSdpaAttention(LlamaAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids) cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
...@@ -671,6 +644,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -671,6 +644,7 @@ class LlamaDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
...@@ -688,6 +662,9 @@ class LlamaDecoderLayer(nn.Module): ...@@ -688,6 +662,9 @@ class LlamaDecoderLayer(nn.Module):
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model into the model
...@@ -705,6 +682,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -705,6 +682,7 @@ class LlamaDecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs, **kwargs,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -867,6 +845,7 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -867,6 +845,7 @@ class LlamaModel(LlamaPreTrainedModel):
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
) )
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Initialize weights and apply final processing # Initialize weights and apply final processing
...@@ -933,10 +912,11 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -933,10 +912,11 @@ class LlamaModel(LlamaPreTrainedModel):
causal_mask = self._update_causal_mask( causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
) )
# embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
...@@ -956,6 +936,7 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -956,6 +936,7 @@ class LlamaModel(LlamaPreTrainedModel):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
position_embeddings,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
...@@ -966,6 +947,7 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -966,6 +947,7 @@ class LlamaModel(LlamaPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
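With the rotary embedding now owned by `LlamaModel` and its cos/sin threaded through every decoder layer (including the gradient-checkpointing path above), the user-facing API is unchanged. A minimal end-to-end sketch with a tiny, randomly initialised config (all sizes are illustrative assumptions):

import torch
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=4, vocab_size=128,
)
model = LlamaModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 10))
with torch.no_grad():
    out = model(input_ids)  # position_embeddings are computed and routed internally
print(out.last_hidden_state.shape)  # (1, 10, 64)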
...@@ -1280,7 +1262,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel): ...@@ -1280,7 +1262,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -85,7 +85,8 @@ class MistralRotaryEmbedding(nn.Module): ...@@ -85,7 +85,8 @@ class MistralRotaryEmbedding(nn.Module):
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad() @torch.no_grad()
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
# TODO(joao): add me back asap :)
def forward(self, x, position_ids): def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size] # x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
...@@ -396,7 +397,8 @@ class MistralFlashAttention2(MistralAttention): ...@@ -396,7 +397,8 @@ class MistralFlashAttention2(MistralAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# TODO(joao): add me back asap :)
class MistralSdpaAttention(MistralAttention): class MistralSdpaAttention(MistralAttention):
""" """
Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
...@@ -492,7 +494,8 @@ MISTRAL_ATTENTION_CLASSES = { ...@@ -492,7 +494,8 @@ MISTRAL_ATTENTION_CLASSES = {
} }
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL
# TODO(joao): add me back asap :)
class MistralDecoderLayer(nn.Module): class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int): def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__() super().__init__()
...@@ -1146,7 +1149,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel): ...@@ -1146,7 +1149,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -1362,7 +1362,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel): ...@@ -1362,7 +1362,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -74,7 +74,8 @@ class OlmoLayerNorm(nn.Module): ...@@ -74,7 +74,8 @@ class OlmoLayerNorm(nn.Module):
ALL_LAYERNORM_LAYERS.append(OlmoLayerNorm) ALL_LAYERNORM_LAYERS.append(OlmoLayerNorm)
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo
# TODO(joao): add me back asap :)
class OlmoRotaryEmbedding(nn.Module): class OlmoRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__() super().__init__()
...@@ -104,7 +105,8 @@ class OlmoRotaryEmbedding(nn.Module): ...@@ -104,7 +105,8 @@ class OlmoRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo # copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo
# TODO(joao): add me back asap :)
class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding): class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding):
"""OlmoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" """OlmoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
...@@ -115,7 +117,8 @@ class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding): ...@@ -115,7 +117,8 @@ class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding):
return cos, sin return cos, sin
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo # copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo
# TODO(joao): add me back asap :)
class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding): class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding):
"""OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" """OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
...@@ -202,7 +205,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: ...@@ -202,7 +205,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class OlmoAttention(nn.Module): class OlmoAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo
# TODO(joao): add me back asap :)
def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None): def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -549,7 +553,8 @@ class OlmoDecoderLayer(nn.Module): ...@@ -549,7 +553,8 @@ class OlmoDecoderLayer(nn.Module):
self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.input_layernorm = OlmoLayerNorm(config.hidden_size)
self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
# TODO(joao): add me back asap :)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -768,7 +773,8 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -768,7 +773,8 @@ class OlmoModel(OlmoPreTrainedModel):
self.embed_tokens = value self.embed_tokens = value
@add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
# Copied from transformers.models.llama.modeling_llama.LlamaModel.forward # copied from transformers.models.llama.modeling_llama.LlamaModel.forward
# TODO(joao): add me back asap :)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
......
...@@ -999,7 +999,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel): ...@@ -999,7 +999,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
@add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -1282,7 +1282,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel): ...@@ -1282,7 +1282,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -1278,7 +1278,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel): ...@@ -1278,7 +1278,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -1370,7 +1370,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): ...@@ -1370,7 +1370,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -1275,7 +1275,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel): ...@@ -1275,7 +1275,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
@add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -1153,7 +1153,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel): ...@@ -1153,7 +1153,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
...@@ -485,6 +485,9 @@ class WhisperTimeStampLogitsProcessor(metaclass=DummyObject): ...@@ -485,6 +485,9 @@ class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
ROPE_INIT_FUNCTIONS = None
class PreTrainedModel(metaclass=DummyObject): class PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -51,12 +51,7 @@ if is_torch_available(): ...@@ -51,12 +51,7 @@ if is_torch_available():
LlamaModel, LlamaModel,
LlamaTokenizer, LlamaTokenizer,
) )
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding, LlamaRotaryEmbedding
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding,
LlamaYarnScalingRotaryEmbedding,
)
class LlamaModelTester: class LlamaModelTester:
...@@ -431,9 +426,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -431,9 +426,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_model_rope_scaling(self): def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10 scaling_factor = 10
short_input_length = 10 short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5) long_input_length = int(config.max_position_embeddings * 1.5)
...@@ -446,11 +438,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -446,11 +438,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
position_ids_long = position_ids_long.unsqueeze(0) position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE # Sanity check original RoPE
original_rope = LlamaRotaryEmbedding( original_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long) original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
...@@ -458,12 +446,8 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -458,12 +446,8 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# Sanity check linear RoPE scaling # Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor" # New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = LlamaLinearScalingRotaryEmbedding( config.rope_scaling = {"type": "linear", "factor": scaling_factor}
head_dim, linear_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
...@@ -476,12 +460,8 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -476,12 +460,8 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# Sanity check Dynamic NTK RoPE scaling # Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases) # with scaling_factor (or that `inv_freq` decreases)
ntk_scaling_rope = LlamaDynamicNTKScalingRotaryEmbedding( config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
head_dim, ntk_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short) torch.testing.assert_close(ntk_cos_short, original_cos_short)
...@@ -493,12 +473,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -493,12 +473,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
# Sanity check Yarn RoPE scaling # Sanity check Yarn RoPE scaling
yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding( # Scaling should be over the entire input
head_dim, config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
max_position_embeddings=config.max_position_embeddings, yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
...@@ -512,6 +489,43 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -512,6 +489,43 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long) torch.testing.assert_close(yarn_sin_long, original_sin_long)
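The rewritten test exercises the config-driven path: instead of one embedding class per scaling strategy, `config.rope_scaling` selects the rule and a single `LlamaRotaryEmbedding(config=config)` applies it. A short usage sketch mirroring the test (the factor value is an arbitrary illustration; both the legacy `"type"` key and the new `"rope_type"` key appear in these tests):

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

config = LlamaConfig()
config.rope_scaling = {"rope_type": "linear", "factor": 2.0}  # illustrative factor
rope = LlamaRotaryEmbedding(config=config)  # one class now covers every scaling rule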
def test_rope_class_retrocompatibility(self):
# Delete me when we remove compatibility for the old API :)
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
config.rope_scaling = {"type": "linear", "factor": 10}
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device

position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)
# Old API -- under the hood, "type": "linear" is set and `LlamaRotaryEmbedding` is called
old_api_rope = LlamaLinearScalingRotaryEmbedding(
config.hidden_size // config.num_attention_heads,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
old_cos_short, old_sin_short = old_api_rope(x, position_ids_short)
old_cos_long, old_sin_long = old_api_rope(x, position_ids_long)
# New API
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
new_api_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
new_cos_short, new_sin_short = new_api_rope(x, position_ids_short)
new_cos_long, new_sin_long = new_api_rope(x, position_ids_long)
# The results should match
torch.testing.assert_close(old_cos_short, new_cos_short)
torch.testing.assert_close(old_sin_short, new_sin_short)
torch.testing.assert_close(old_cos_long, new_cos_long)
torch.testing.assert_close(old_sin_long, new_sin_long)
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@require_bitsandbytes @require_bitsandbytes
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import LlamaConfig
from transformers.testing_utils import is_torch_available, require_torch, torch_device
if is_torch_available():
import torch
from transformers import ROPE_INIT_FUNCTIONS
from transformers.modeling_rope_utils import rope_config_validation
@require_torch
class RopeTest(unittest.TestCase):
def test_rope_validation(self):
config = LlamaConfig()
all_rope_types = ROPE_INIT_FUNCTIONS.keys()
# The base config is always valid (default RoPE)
rope_config_validation(config)
# If we explicitly set the other RoPE types, then validation should fail
for rope_type in all_rope_types:
if rope_type != "default":
config.rope_scaling = {"rope_type": rope_type}
with self.assertRaises(KeyError):
rope_config_validation(config)
# Parameters are exclusive to their own RoPE type, and should raise an exception if incorrectly passed
valid_param_mapping = {
"factor": ["linear", "dynamic", "yarn", "longrope"],
"attention_factor": ["yarn", "longrope"],
"beta_fast": ["yarn"],
"beta_slow": ["yarn"],
"short_factor": ["longrope"],
"long_factor": ["longrope"],
}
for rope_type in all_rope_types:
if rope_type == "default":
continue # checked above
for param, valid_rope_types in valid_param_mapping.items():
# Set `param` with a dummy value -- we want to test the dict key
config.rope_scaling = {"rope_type": rope_type, param: True}
if rope_type in valid_rope_types:
continue
else:
with self.assertRaises(KeyError):
rope_config_validation(config)
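For contrast with the failure cases above, a sketch of combinations the mapping treats as valid; the exact required/optional keys are an assumption here, and `rope_config_validation` remains the authoritative reference:

def test_rope_validation_accepts_matching_params(self):
    # Sketch: the same parameter keys pass validation when paired with a rope_type that accepts them
    config = LlamaConfig()

    config.rope_scaling = {"rope_type": "linear", "factor": 4.0}
    rope_config_validation(config)  # expected: no exception

    config.rope_scaling = {"rope_type": "yarn", "factor": 4.0, "beta_fast": 32.0, "beta_slow": 1.0}
    rope_config_validation(config)  # expected: no exception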
def test_default_rope_function_bc(self):
config = LlamaConfig()
device = torch_device
rope_kwargs = {
"rope_type": "default",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
}
rope_fn = ROPE_INIT_FUNCTIONS["default"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)
def test_linear_rope_function_bc(self):
config = LlamaConfig()
config.rope_scaling = {"rope_type": "linear", "factor": 10.0}
device = torch_device
rope_kwargs = {
"rope_type": "linear",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
"factor": 10.0,
}
rope_fn = ROPE_INIT_FUNCTIONS["linear"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)
def test_dynamic_rope_function_bc(self):
config = LlamaConfig()
config.rope_scaling = {"rope_type": "dynamic", "factor": 10.0}
device = torch_device
rope_kwargs = {
"rope_type": "dynamic",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
"factor": 10.0,
}
rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)
# TODO(joao): numerical checks for the different RoPE fns
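One way the TODO above could be started, assuming the linear rule simply divides the default inverse frequencies by `factor` (equivalent to dividing the position indices by `factor`); treat this as a sketch, not part of the committed test suite:

def test_linear_rope_numerics(self):
    # Assumption: linear scaling rescales the default inv_freq by 1/factor
    factor = 10.0
    config = LlamaConfig()
    default_inv_freq, _ = ROPE_INIT_FUNCTIONS["default"](config=config, device=torch_device)

    config.rope_scaling = {"rope_type": "linear", "factor": factor}
    linear_inv_freq, _ = ROPE_INIT_FUNCTIONS["linear"](config=config, device=torch_device)

    torch.testing.assert_close(linear_inv_freq, default_inv_freq / factor)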