"docs/vscode:/vscode.git/clone" did not exist on "cee2d2135fbedc65072eff00c48bdf8bb36ad57a"
Unverified Commit 2e113422 authored by Joao Gante, committed by GitHub
parent 5a4a76ed
......@@ -1295,6 +1295,7 @@ else:
)
_import_structure["modeling_flash_attention_utils"] = []
_import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
_import_structure["modeling_utils"] = ["PreTrainedModel"]
# PyTorch models structure
......@@ -6010,6 +6011,7 @@ if TYPE_CHECKING:
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
from .modeling_utils import PreTrainedModel
from .models.albert import (
AlbertForMaskedLM,
......
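With this export, the RoPE initialization registry becomes part of the public `transformers` namespace. A minimal sketch of what that enables, assuming only the registry keys named elsewhere in this commit (the exact key set beyond 'default' is inferred from the docstring and tests below):

from transformers import LlamaConfig, ROPE_INIT_FUNCTIONS

# The registry maps a rope_type string to its parameter-initialization function.
print(sorted(ROPE_INIT_FUNCTIONS.keys()))  # e.g. 'default', 'linear', 'dynamic', 'yarn', 'longrope'

# Each function accepts a model config and returns the inverse frequencies as its first element,
# mirroring the `rope_fn(config=config, device=device)[0]` pattern used in the new RoPE tests.
inv_freq = ROPE_INIT_FUNCTIONS["default"](config=LlamaConfig(), device="cpu")[0]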
This diff is collapsed.
......@@ -80,7 +80,8 @@ class ChameleonRMSNorm(nn.Module):
ALL_LAYERNORM_LAYERS.append(ChameleonRMSNorm)
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
......@@ -110,7 +111,8 @@ class ChameleonRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Chameleon
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
"""ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
......@@ -121,7 +123,8 @@ class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
return cos, sin
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Chameleon
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
"""ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
......@@ -265,7 +268,8 @@ class ChameleonAttention(nn.Module):
self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim))
self._init_rope()
# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon
# copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon
# TODO(joao): add me back asap :)
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = ChameleonRotaryEmbedding(
......@@ -358,7 +362,8 @@ class ChameleonAttention(nn.Module):
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonFlashAttention2(ChameleonAttention):
"""
Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays
......@@ -576,7 +581,8 @@ CHAMELEON_ATTENTION_CLASSES = {
}
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
# TODO(joao): add me back asap :)
class ChameleonDecoderLayer(nn.Module):
def __init__(self, config: ChameleonConfig, layer_idx: int):
super().__init__()
......
......@@ -295,7 +295,8 @@ class CohereAttention(nn.Module):
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
# TODO(joao): add me back asap :)
class CohereFlashAttention2(CohereAttention):
"""
Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays
......@@ -409,7 +410,8 @@ class CohereFlashAttention2(CohereAttention):
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention Llama->Cohere
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention Llama->Cohere
# TODO(joao): add me back asap :)
class CohereSdpaAttention(CohereAttention):
"""
Cohere attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
......@@ -697,7 +699,8 @@ COHERE_INPUTS_DOCSTRING = r"""
"The bare Cohere Model outputting raw hidden-states without any specific head on top.",
COHERE_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere
# copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere
# TODO(joao): add me back asap :)
class CohereModel(CoherePreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CohereDecoderLayer`]
......
......@@ -1624,7 +1624,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel):
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
......@@ -1363,7 +1363,7 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel):
@add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
......@@ -20,10 +20,7 @@
"""LLaMA model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
from ...modeling_rope_utils import rope_config_validation
class LlamaConfig(PretrainedConfig):
......@@ -84,22 +81,35 @@ class LlamaConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
strategies: linear, dynamic and yarn. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
For the `yarn` strategy, the dictionary may also contain the following fields:
`original_max_position_embeddings` (`int`, *optional*):
The original maximum sequence length. This is used to scale the RoPE embeddings.
Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
`max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
to determine which scaling to apply.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
`max_position_embeddings`.
`attention_factor` (`float`, *optional*):
The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the `original_max_position_embeddings/max_position_embeddings` ratio.
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Parameter to set the boundary for extrapolation (only) in the linear ramp function.
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Parameter to set the boundary for interpolation (only) in the linear ramp function.
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
......@@ -167,11 +177,13 @@ class LlamaConfig(PretrainedConfig):
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
# Validate the correctness of rotary position embeddings parameters
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
......@@ -179,60 +191,3 @@ class LlamaConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) < 2:
raise ValueError(
"`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
if rope_scaling_type != "yarn":
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6:
raise ValueError(
"`rope_scaling` with type "
f"{rope_scaling_type}"
" must be a dictionary with a maximum of six fields, `type`, `factor`,"
"`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, "
f"got {self.rope_scaling}"
)
original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
attention_factor = self.rope_scaling.get("attention_factor", None)
beta_fast = self.rope_scaling.get("beta_fast", None)
beta_slow = self.rope_scaling.get("beta_slow", None)
if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
raise ValueError(
f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
)
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
raise ValueError(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
b_fast = beta_fast if beta_fast is not None else 32
b_slow = beta_slow if beta_slow is not None else 1
if b_fast < b_slow:
raise ValueError(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
)
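The per-model `_rope_scaling_validation` above is replaced by the shared `rope_config_validation` helper, and the docstring now documents `rope_scaling` as a typed dictionary. A minimal sketch of the new contract, using only field names listed in the docstring (the concrete values are illustrative):

from transformers import LlamaConfig
from transformers.modeling_rope_utils import rope_config_validation

# Validation now runs inside LlamaConfig.__init__ via rope_config_validation(self).
config = LlamaConfig(
    rope_scaling={"rope_type": "yarn", "factor": 2.0, "beta_fast": 32.0, "beta_slow": 1.0}
)

# It can also be invoked directly, e.g. after mutating the config in place,
# as the new RoPE tests do.
config.rope_scaling = {"rope_type": "linear", "factor": 4.0}
rope_config_validation(config)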
......@@ -85,7 +85,8 @@ class MistralRotaryEmbedding(nn.Module):
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
# TODO(joao): add me back asap :)
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
......@@ -396,7 +397,8 @@ class MistralFlashAttention2(MistralAttention):
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# TODO(joao): add me back asap :)
class MistralSdpaAttention(MistralAttention):
"""
Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
......@@ -492,7 +494,8 @@ MISTRAL_ATTENTION_CLASSES = {
}
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL
# TODO(joao): add me back asap :)
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
......@@ -1146,7 +1149,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
......@@ -1362,7 +1362,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
......@@ -74,7 +74,8 @@ class OlmoLayerNorm(nn.Module):
ALL_LAYERNORM_LAYERS.append(OlmoLayerNorm)
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo
# TODO(joao): add me back asap :)
class OlmoRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
......@@ -104,7 +105,8 @@ class OlmoRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo
# TODO(joao): add me back asap :)
class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding):
"""OlmoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
......@@ -115,7 +117,8 @@ class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding):
return cos, sin
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo
# TODO(joao): add me back asap :)
class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding):
"""OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
......@@ -202,7 +205,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class OlmoAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo
# copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo
# TODO(joao): add me back asap :)
def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
......@@ -549,7 +553,8 @@ class OlmoDecoderLayer(nn.Module):
self.input_layernorm = OlmoLayerNorm(config.hidden_size)
self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
# TODO(joao): add me back asap :)
def forward(
self,
hidden_states: torch.Tensor,
......@@ -768,7 +773,8 @@ class OlmoModel(OlmoPreTrainedModel):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
# Copied from transformers.models.llama.modeling_llama.LlamaModel.forward
# copied from transformers.models.llama.modeling_llama.LlamaModel.forward
# TODO(joao): add me back asap :)
def forward(
self,
input_ids: torch.LongTensor = None,
......
......@@ -999,7 +999,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
@add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
......@@ -1282,7 +1282,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
......@@ -1278,7 +1278,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
......@@ -1370,7 +1370,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
......@@ -1275,7 +1275,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
@add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
......@@ -1153,7 +1153,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
......
......@@ -485,6 +485,9 @@ class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"])
ROPE_INIT_FUNCTIONS = None
class PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -51,12 +51,7 @@ if is_torch_available():
LlamaModel,
LlamaTokenizer,
)
from transformers.models.llama.modeling_llama import (
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding,
LlamaYarnScalingRotaryEmbedding,
)
from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding, LlamaRotaryEmbedding
class LlamaModelTester:
......@@ -431,9 +426,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
......@@ -446,11 +438,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE
original_rope = LlamaRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
......@@ -458,12 +446,8 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = LlamaLinearScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
linear_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
......@@ -476,12 +460,8 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases)
ntk_scaling_rope = LlamaDynamicNTKScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
ntk_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
......@@ -493,12 +473,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
# Sanity check Yarn RoPE scaling
yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
# Scaling should be over the entire input
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
......@@ -512,6 +489,43 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
def test_rope_class_retrocompatibility(self):
# Delete me when we remove compatibility for the old API :)
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
config.rope_scaling = {"type": "linear", "factor": 10}
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)
# Old API -- under the hood, "type": "linear" is set and `LlamaRotaryEmbedding` is called
old_api_rope = LlamaLinearScalingRotaryEmbedding(
config.hidden_size // config.num_attention_heads,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
old_cos_short, old_sin_short = old_api_rope(x, position_ids_short)
old_cos_long, old_sin_long = old_api_rope(x, position_ids_long)
# New API
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
new_api_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
new_cos_short, new_sin_short = new_api_rope(x, position_ids_short)
new_cos_long, new_sin_long = new_api_rope(x, position_ids_long)
# The results should match
torch.testing.assert_close(old_cos_short, new_cos_short)
torch.testing.assert_close(old_sin_short, new_sin_short)
torch.testing.assert_close(old_cos_long, new_cos_long)
torch.testing.assert_close(old_sin_long, new_sin_long)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import LlamaConfig
from transformers.testing_utils import is_torch_available, require_torch, torch_device
if is_torch_available():
import torch
from transformers import ROPE_INIT_FUNCTIONS
from transformers.modeling_rope_utils import rope_config_validation
@require_torch
class RopeTest(unittest.TestCase):
def test_rope_validation(self):
config = LlamaConfig()
all_rope_types = ROPE_INIT_FUNCTIONS.keys()
# The base config is always valid (default RoPE)
rope_config_validation(config)
# If we explicitly set the other RoPE types, then validation should fail
for rope_type in all_rope_types:
if rope_type != "default":
config.rope_scaling = {"rope_type": rope_type}
with self.assertRaises(KeyError):
rope_config_validation(config)
# Parameters are exclusive to their own RoPE type, and should raise an exception if incorrectly passed
valid_param_mapping = {
"factor": ["linear", "dynamic", "yarn", "longrope"],
"attention_factor": ["yarn", "longrope"],
"beta_fast": ["yarn"],
"beta_slow": ["yarn"],
"short_factor": ["longrope"],
"long_factor": ["longrope"],
}
for rope_type in all_rope_types:
if rope_type == "default":
continue # checked above
for param, valid_rope_types in valid_param_mapping.items():
# Set `param` with a dummy value -- we want to test the dict key
config.rope_scaling = {"rope_type": rope_type, param: True}
if rope_type in valid_rope_types:
continue
else:
with self.assertRaises(KeyError):
rope_config_validation(config)
def test_default_rope_function_bc(self):
config = LlamaConfig()
device = torch_device
rope_kwargs = {
"rope_type": "default",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
}
rope_fn = ROPE_INIT_FUNCTIONS["default"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)
def test_linear_rope_function_bc(self):
config = LlamaConfig()
config.rope_scaling = {"rope_type": "linear", "factor": 10.0}
device = torch_device
rope_kwargs = {
"rope_type": "linear",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
"factor": 10.0,
}
rope_fn = ROPE_INIT_FUNCTIONS["linear"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)
def test_dynamic_rope_function_bc(self):
config = LlamaConfig()
config.rope_scaling = {"rope_type": "dynamic", "factor": 10.0}
device = torch_device
rope_kwargs = {
"rope_type": "dynamic",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
"factor": 10.0,
}
rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)
# TODO(joao): numerical checks for the different RoPE fns
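One possible shape for the numerical checks mentioned in the TODO above, sketched under the assumption that the 'linear' strategy divides the default inverse frequencies by `factor` (scaling positions by 1/factor and scaling `inv_freq` by 1/factor are equivalent, since the rotary angle is position * inv_freq):

import torch
from transformers import LlamaConfig, ROPE_INIT_FUNCTIONS

config = LlamaConfig()
factor = 10.0
default_freqs = ROPE_INIT_FUNCTIONS["default"](config=config, device="cpu")[0]

config.rope_scaling = {"rope_type": "linear", "factor": factor}
linear_freqs = ROPE_INIT_FUNCTIONS["linear"](config=config, device="cpu")[0]

# Under the stated assumption, the linearly scaled frequencies are the default ones divided by `factor`.
torch.testing.assert_close(linear_freqs, default_freqs / factor)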