Unverified Commit 2e113422 authored by Joao Gante, committed by GitHub
parent 5a4a76ed
...@@ -1295,6 +1295,7 @@ else:
)
_import_structure["modeling_flash_attention_utils"] = []
_import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
_import_structure["modeling_utils"] = ["PreTrainedModel"]
# PyTorch models structure
...@@ -6010,6 +6011,7 @@ if TYPE_CHECKING:
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
from .modeling_utils import PreTrainedModel
from .models.albert import (
AlbertForMaskedLM,
...
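With the new `modeling_rope_utils` module added to the import structure, `ROPE_INIT_FUNCTIONS` becomes importable from the top-level package. A minimal sketch of the new entry point (the set of keys is taken from the updated `rope_scaling` docstring and the tests further down in this diff):

```python
from transformers import ROPE_INIT_FUNCTIONS

# Mapping from RoPE variant name to its initialization function; per the docstring and
# tests below, the keys include 'default', 'linear', 'dynamic', 'yarn' and 'longrope'.
print(sorted(ROPE_INIT_FUNCTIONS.keys()))
```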
This diff is collapsed.
...@@ -80,7 +80,8 @@ class ChameleonRMSNorm(nn.Module):
ALL_LAYERNORM_LAYERS.append(ChameleonRMSNorm)
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
...@@ -110,7 +111,8 @@ class ChameleonRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
"""ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
...@@ -121,7 +123,8 @@ class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
return cos, sin
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
"""ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
...@@ -265,7 +268,8 @@ class ChameleonAttention(nn.Module):
self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim))
self._init_rope()
# copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon
# TODO(joao): add me back asap :)
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = ChameleonRotaryEmbedding(
...@@ -358,7 +362,8 @@ class ChameleonAttention(nn.Module):
return attn_output, attn_weights, past_key_value
# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonFlashAttention2(ChameleonAttention):
"""
Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays
...@@ -576,7 +581,8 @@ CHAMELEON_ATTENTION_CLASSES = {
}
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
# TODO(joao): add me back asap :)
class ChameleonDecoderLayer(nn.Module):
def __init__(self, config: ChameleonConfig, layer_idx: int):
super().__init__()
...
...@@ -295,7 +295,8 @@ class CohereAttention(nn.Module):
return attn_output, attn_weights, past_key_value
# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
# TODO(joao): add me back asap :)
class CohereFlashAttention2(CohereAttention):
"""
Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays
...@@ -409,7 +410,8 @@ class CohereFlashAttention2(CohereAttention):
return attn_output, attn_weights, past_key_value
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention Llama->Cohere
# TODO(joao): add me back asap :)
class CohereSdpaAttention(CohereAttention):
"""
Cohere attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
...@@ -697,7 +699,8 @@ COHERE_INPUTS_DOCSTRING = r"""
"The bare Cohere Model outputting raw hidden-states without any specific head on top.",
COHERE_START_DOCSTRING,
)
# copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere
# TODO(joao): add me back asap :)
class CohereModel(CoherePreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CohereDecoderLayer`]
...
...@@ -1624,7 +1624,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel):
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
...
...@@ -1363,7 +1363,7 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel):
@add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
...
...@@ -20,10 +20,7 @@
"""LLaMA model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...modeling_rope_utils import rope_config_validation
logger = logging.get_logger(__name__)
class LlamaConfig(PretrainedConfig):
...@@ -84,22 +81,35 @@ class LlamaConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
strategies: linear, dynamic and yarn. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
For the `yarn` strategy, the dictionary may also contain the following fields:
`original_max_position_embeddings` (`int`, *optional*):
The original maximum sequence length. This is used to scale the RoPE embeddings.
`attention_factor` (`float`, *optional*):
The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the `original_max_position_embeddings/max_position_embeddings` ratio.
`beta_fast` (`float`, *optional*):
Parameter to set the boundary for extrapolation (only) in the linear ramp function.
`beta_slow` (`float`, *optional*):
Parameter to set the boundary for interpolation (only) in the linear ramp function.
Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
`max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
to determine which scaling to apply.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
`max_position_embeddings`.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to the value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
...@@ -167,11 +177,13 @@ class LlamaConfig(PretrainedConfig):
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
# Validate the correctness of rotary position embeddings parameters
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
...@@ -179,60 +191,3 @@ class LlamaConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) < 2:
raise ValueError(
"`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
if rope_scaling_type != "yarn":
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6:
raise ValueError(
"`rope_scaling` with type "
f"{rope_scaling_type}"
" must be a dictionary with a maximum of six fields, `type`, `factor`,"
"`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, "
f"got {self.rope_scaling}"
)
original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
attention_factor = self.rope_scaling.get("attention_factor", None)
beta_fast = self.rope_scaling.get("beta_fast", None)
beta_slow = self.rope_scaling.get("beta_slow", None)
if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
raise ValueError(
f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
)
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
raise ValueError(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
b_fast = beta_fast if beta_fast is not None else 32
b_slow = beta_slow if beta_slow is not None else 1
if b_fast < b_slow:
raise ValueError(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
)
...@@ -85,7 +85,8 @@ class MistralRotaryEmbedding(nn.Module):
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
# TODO(joao): add me back asap :)
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
...@@ -396,7 +397,8 @@ class MistralFlashAttention2(MistralAttention):
return attn_output, attn_weights, past_key_value
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# TODO(joao): add me back asap :)
class MistralSdpaAttention(MistralAttention):
"""
Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
...@@ -492,7 +494,8 @@ MISTRAL_ATTENTION_CLASSES = {
}
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL
# TODO(joao): add me back asap :)
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
...@@ -1146,7 +1149,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
...
...@@ -1362,7 +1362,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
...
...@@ -74,7 +74,8 @@ class OlmoLayerNorm(nn.Module):
ALL_LAYERNORM_LAYERS.append(OlmoLayerNorm)
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo
# TODO(joao): add me back asap :)
class OlmoRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
...@@ -104,7 +105,8 @@ class OlmoRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo
# TODO(joao): add me back asap :)
class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding):
"""OlmoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
...@@ -115,7 +117,8 @@ class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding):
return cos, sin
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo
# TODO(joao): add me back asap :)
class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding):
"""OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
...@@ -202,7 +205,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class OlmoAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo
# TODO(joao): add me back asap :)
def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
...@@ -549,7 +553,8 @@ class OlmoDecoderLayer(nn.Module):
self.input_layernorm = OlmoLayerNorm(config.hidden_size)
self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
# TODO(joao): add me back asap :)
def forward(
self,
hidden_states: torch.Tensor,
...@@ -768,7 +773,8 @@ class OlmoModel(OlmoPreTrainedModel):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
# copied from transformers.models.llama.modeling_llama.LlamaModel.forward
# TODO(joao): add me back asap :)
def forward(
self,
input_ids: torch.LongTensor = None,
...
...@@ -999,7 +999,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
@add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
...
...@@ -1282,7 +1282,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
...
...@@ -1278,7 +1278,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
...
...@@ -1370,7 +1370,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
...
...@@ -1275,7 +1275,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
@add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
...
...@@ -1153,7 +1153,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
...
...@@ -485,6 +485,9 @@ class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"])
ROPE_INIT_FUNCTIONS = None
class PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
...
...@@ -51,12 +51,7 @@ if is_torch_available():
LlamaModel,
LlamaTokenizer,
)
from transformers.models.llama.modeling_llama import (
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding,
LlamaYarnScalingRotaryEmbedding,
)
from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding, LlamaRotaryEmbedding
class LlamaModelTester:
...@@ -431,9 +426,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
...@@ -446,11 +438,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE
original_rope = LlamaRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
...@@ -458,12 +446,8 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = LlamaLinearScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
linear_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
...@@ -476,12 +460,8 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases)
ntk_scaling_rope = LlamaDynamicNTKScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
ntk_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
...@@ -493,12 +473,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
# Sanity check Yarn RoPE scaling
yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
# Scaling should be over the entire input
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
...@@ -512,6 +489,43 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
def test_rope_class_retrocompatibility(self):
# Delete me when we remove compatibility for the old API :)
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
config.rope_scaling = {"type": "linear", "factor": 10}
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)
# Old API -- under the hood, "type": "linear" is set and `LlamaRotaryEmbedding` is called
old_api_rope = LlamaLinearScalingRotaryEmbedding(
config.hidden_size // config.num_attention_heads,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
old_cos_short, old_sin_short = old_api_rope(x, position_ids_short)
old_cos_long, old_sin_long = old_api_rope(x, position_ids_long)
# New API
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
new_api_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
new_cos_short, new_sin_short = new_api_rope(x, position_ids_short)
new_cos_long, new_sin_long = new_api_rope(x, position_ids_long)
# The results should match
torch.testing.assert_close(old_cos_short, new_cos_short)
torch.testing.assert_close(old_sin_short, new_sin_short)
torch.testing.assert_close(old_cos_long, new_cos_long)
torch.testing.assert_close(old_sin_long, new_sin_long)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
...
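The refactored tests above exercise the consolidated constructor, where `LlamaRotaryEmbedding(config=config)` picks the scaling variant from `config.rope_scaling`. A compact usage sketch of that API outside the test harness (CPU device and the 'dynamic' variant chosen here purely for illustration):

```python
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

config = LlamaConfig(rope_scaling={"rope_type": "dynamic", "factor": 4.0})
rotary_emb = LlamaRotaryEmbedding(config=config)  # variant resolved from config.rope_scaling

x = torch.randn(1, dtype=torch.float32)  # as in the tests, only dtype/device are read from x
position_ids = torch.arange(16, dtype=torch.long).unsqueeze(0)
cos, sin = rotary_emb(x, position_ids)  # per-position cos/sin tables for rotary attention
```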
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import LlamaConfig
from transformers.testing_utils import is_torch_available, require_torch, torch_device
if is_torch_available():
import torch
from transformers import ROPE_INIT_FUNCTIONS
from transformers.modeling_rope_utils import rope_config_validation
@require_torch
class RopeTest(unittest.TestCase):
def test_rope_validation(self):
config = LlamaConfig()
all_rope_types = ROPE_INIT_FUNCTIONS.keys()
# The base config is always valid (default RoPE)
rope_config_validation(config)
# If we explicitly set the other RoPE types, then validation should fail
for rope_type in all_rope_types:
if rope_type != "default":
config.rope_scaling = {"rope_type": rope_type}
with self.assertRaises(KeyError):
rope_config_validation(config)
# Parameters are exclusive to their own RoPE type, and should raise an exception if incorrectly passed
valid_param_mapping = {
"factor": ["linear", "dynamic", "yarn", "longrope"],
"attention_factor": ["yarn", "longrope"],
"beta_fast": ["yarn"],
"beta_slow": ["yarn"],
"short_factor": ["longrope"],
"long_factor": ["longrope"],
}
for rope_type in all_rope_types:
if rope_type == "default":
continue # checked above
for param, valid_rope_types in valid_param_mapping.items():
# Set `param` with a dummy value -- we want to test the dict key
config.rope_scaling = {"rope_type": rope_type, param: True}
if rope_type in valid_rope_types:
continue
else:
with self.assertRaises(KeyError):
rope_config_validation(config)
def test_default_rope_function_bc(self):
config = LlamaConfig()
device = torch_device
rope_kwargs = {
"rope_type": "default",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
}
rope_fn = ROPE_INIT_FUNCTIONS["default"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)
def test_linear_rope_function_bc(self):
config = LlamaConfig()
config.rope_scaling = {"rope_type": "linear", "factor": 10.0}
device = torch_device
rope_kwargs = {
"rope_type": "linear",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
"factor": 10.0,
}
rope_fn = ROPE_INIT_FUNCTIONS["linear"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)
def test_dynamic_rope_function_bc(self):
config = LlamaConfig()
config.rope_scaling = {"rope_type": "dynamic", "factor": 10.0}
device = torch_device
rope_kwargs = {
"rope_type": "dynamic",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
"factor": 10.0,
}
rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)
# TODO(joao): numerical checks for the different RoPE fns
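Beyond the backward-compatibility checks above, the same `ROPE_INIT_FUNCTIONS` entries can be called directly. A minimal sketch (the tests only compare the first returned element, the inverse-frequency tensor; treating the rest of the return value as the recommended attention scaling is an assumption here):

```python
from transformers import LlamaConfig, ROPE_INIT_FUNCTIONS

config = LlamaConfig()
config.rope_scaling = {"rope_type": "linear", "factor": 10.0}

rope_fn = ROPE_INIT_FUNCTIONS["linear"]
inv_freq = rope_fn(config=config, device="cpu")[0]  # first element: inverse frequencies
print(inv_freq.shape)  # expected to be (head_dim // 2,) for the default head layout
```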