Unverified commit a8b70304, authored by Harry Mellor, committed by GitHub
Browse files

Update `rope_scaling` to `rope_parameters` in preparation for Transformers v5 (#28542)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent d44e9df7
...@@ -872,12 +872,12 @@ steps: ...@@ -872,12 +872,12 @@ steps:
optional: true optional: true
commands: commands:
- pip install --upgrade git+https://github.com/huggingface/transformers - pip install --upgrade git+https://github.com/huggingface/transformers
- pytest -v -s tests/models/test_initialization.py -k 'not (Gemma3 or ModernBert or Qwen2_5_VL or Qwen2_5vl or Qwen2VL or TransformersMultiModalEmbeddingModel or TransformersMultiModalForSequenceClassification or Ultravox or Phi4Multimodal or LlavaNextVideo or MiniCPMO or Lfm2Moe or PaliGemma or RobertaForSequenceClassification or Ovis2_5 or Fuyu or DeepseekOCR or KimiVL)' - pytest -v -s tests/models/test_initialization.py -k 'not (Ultravox or Phi4Multimodal or MiniCPMO or Lfm2Moe or RobertaForSequenceClassification or Ovis2_5 or DeepseekOCR or KimiVL)'
- pytest -v -s tests/models/test_transformers.py - pytest -v -s tests/models/test_transformers.py
# - pytest -v -s tests/models/multimodal/processing/ # - pytest -v -s tests/models/multimodal/processing/
- pytest -v -s tests/models/multimodal/test_mapping.py -k 'not (Gemma3 or Qwen2VL or Qwen2_5_VL)' - pytest -v -s tests/models/multimodal/test_mapping.py
- python3 examples/offline_inference/basic/chat.py - python3 examples/offline_inference/basic/chat.py
# - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
# Whisper needs spawn method to avoid deadlock # Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
# #
# The CSV file (named with current date/time) contains these columns: # The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position, # model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99, # is_neox_style, rope_parameters, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max, # torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup # speedup
# #
...@@ -86,9 +86,8 @@ def benchmark_mrope( ...@@ -86,9 +86,8 @@ def benchmark_mrope(
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
max_position: int = 8192, max_position: int = 8192,
rope_theta: float = 10000,
is_neox_style: bool = True, is_neox_style: bool = True,
rope_scaling: dict[str, Any] = None, rope_parameters: dict[str, Any] | None = None,
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
seed: int = 0, seed: int = 0,
warmup_iter: int = 10, warmup_iter: int = 10,
...@@ -102,9 +101,8 @@ def benchmark_mrope( ...@@ -102,9 +101,8 @@ def benchmark_mrope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim, rotary_dim=head_dim,
max_position=max_position, max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_scaling=rope_scaling, rope_parameters=rope_parameters,
dtype=dtype, dtype=dtype,
).to(device=device) ).to(device=device)
...@@ -203,9 +201,8 @@ def benchmark_mrope( ...@@ -203,9 +201,8 @@ def benchmark_mrope(
num_kv_heads, num_kv_heads,
head_dim, head_dim,
max_position, max_position,
rope_theta,
is_neox_style, is_neox_style,
str(rope_scaling), str(rope_parameters),
str(dtype).split(".")[-1], str(dtype).split(".")[-1],
torch_stats["mean"], torch_stats["mean"],
torch_stats["median"], torch_stats["median"],
...@@ -255,9 +252,8 @@ if __name__ == "__main__": ...@@ -255,9 +252,8 @@ if __name__ == "__main__":
"num_kv_heads", "num_kv_heads",
"head_dim", "head_dim",
"max_position", "max_position",
"rope_theta",
"is_neox_style", "is_neox_style",
"rope_scaling", "rope_parameters",
"dtype", "dtype",
"torch_mean", "torch_mean",
"torch_median", "torch_median",
...@@ -303,7 +299,7 @@ if __name__ == "__main__": ...@@ -303,7 +299,7 @@ if __name__ == "__main__":
q_size = num_heads * head_dim q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim kv_size = num_kv_heads * head_dim
is_neox_style = True is_neox_style = True
rope_theta = config.rope_theta rope_parameters = config.rope_parameters
max_position = config.max_position_embeddings max_position = config.max_position_embeddings
for num_tokens in num_tokens_list: for num_tokens in num_tokens_list:
...@@ -315,9 +311,8 @@ if __name__ == "__main__": ...@@ -315,9 +311,8 @@ if __name__ == "__main__":
num_heads=num_heads, num_heads=num_heads,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
max_position=max_position, max_position=max_position,
rope_theta=rope_theta,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_scaling=config.rope_scaling, rope_parameters=rope_parameters,
dtype=getattr(torch, args.dtype), dtype=getattr(torch, args.dtype),
seed=args.seed, seed=args.seed,
warmup_iter=args.warmup_iter, warmup_iter=args.warmup_iter,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
This script demonstrates how to extend the context length This script demonstrates how to extend the context length
of a Qwen model using the YARN method (rope_scaling) of a Qwen model using the YARN method (rope_parameters)
and run a simple chat example. and run a simple chat example.
Usage: Usage:
...@@ -19,8 +19,8 @@ def create_llm(): ...@@ -19,8 +19,8 @@ def create_llm():
# Use yarn to extend context # Use yarn to extend context
hf_overrides = { hf_overrides = {
"rope_parameters": {
"rope_theta": rope_theta, "rope_theta": rope_theta,
"rope_scaling": {
"rope_type": "yarn", "rope_type": "yarn",
"factor": factor, "factor": factor,
"original_max_position_embeddings": original_max_position_embeddings, "original_max_position_embeddings": original_max_position_embeddings,
......
...@@ -137,7 +137,7 @@ class TestRotaryEmbedding(torch.nn.Module): ...@@ -137,7 +137,7 @@ class TestRotaryEmbedding(torch.nn.Module):
self.head_dim, self.head_dim,
rotary_dim=self.rotary_dim, rotary_dim=self.rotary_dim,
max_position=max_position, max_position=max_position,
base=base, rope_parameters={"rope_type": "default", "rope_theta": base},
) )
def forward(self, positions, q, k): def forward(self, positions, q, k):
...@@ -172,7 +172,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module): ...@@ -172,7 +172,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
base=base, rope_parameters={"rope_type": "default", "rope_theta": base},
) )
def forward(self, positions, hidden_states): def forward(self, positions, hidden_states):
......
...@@ -5,11 +5,11 @@ from typing import NamedTuple ...@@ -5,11 +5,11 @@ from typing import NamedTuple
import pytest import pytest
import torch import torch
from packaging.version import Version from packaging.version import Version
from transformers import AutoConfig
from transformers import __version__ as TRANSFORMERS_VERSION from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
...@@ -98,8 +98,7 @@ def test_mrope( ...@@ -98,8 +98,7 @@ def test_mrope(
atol = model_info.atol atol = model_info.atol
rtol = model_info.rtol rtol = model_info.rtol
config = AutoConfig.from_pretrained(model_name) config = get_config(model_name, False).get_text_config()
config = config.get_text_config()
# get the model config # get the model config
total_num_kv_heads = config.num_key_value_heads total_num_kv_heads = config.num_key_value_heads
...@@ -113,7 +112,6 @@ def test_mrope( ...@@ -113,7 +112,6 @@ def test_mrope(
) )
is_neox_style = True is_neox_style = True
rope_theta = config.rope_theta
max_position = config.max_position_embeddings max_position = config.max_position_embeddings
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor) rotary_dim = int(head_dim * partial_rotary_factor)
...@@ -122,9 +120,8 @@ def test_mrope( ...@@ -122,9 +120,8 @@ def test_mrope(
head_size=head_dim, head_size=head_dim,
rotary_dim=rotary_dim, rotary_dim=rotary_dim,
max_position=max_position, max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_scaling=config.rope_scaling, rope_parameters=config.rope_parameters,
dtype=dtype, dtype=dtype,
).to(device=device) ).to(device=device)
...@@ -173,8 +170,7 @@ def test_mrope_torch_compile_tracing( ...@@ -173,8 +170,7 @@ def test_mrope_torch_compile_tracing(
atol = model_info.atol atol = model_info.atol
rtol = model_info.rtol rtol = model_info.rtol
config = AutoConfig.from_pretrained(model_name) config = get_config(model_name, False).get_text_config()
config = config.get_text_config()
# get the model config # get the model config
total_num_kv_heads = config.num_key_value_heads total_num_kv_heads = config.num_key_value_heads
...@@ -187,7 +183,6 @@ def test_mrope_torch_compile_tracing( ...@@ -187,7 +183,6 @@ def test_mrope_torch_compile_tracing(
else config.hidden_size // total_num_heads else config.hidden_size // total_num_heads
) )
is_neox_style = True is_neox_style = True
rope_theta = config.rope_theta
max_position = config.max_position_embeddings max_position = config.max_position_embeddings
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor) rotary_dim = int(head_dim * partial_rotary_factor)
...@@ -196,9 +191,8 @@ def test_mrope_torch_compile_tracing( ...@@ -196,9 +191,8 @@ def test_mrope_torch_compile_tracing(
head_size=head_dim, head_size=head_dim,
rotary_dim=rotary_dim, rotary_dim=rotary_dim,
max_position=max_position, max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_scaling=config.rope_scaling, rope_parameters=config.rope_parameters,
dtype=dtype, dtype=dtype,
).to(device=device) ).to(device=device)
......
...@@ -74,7 +74,7 @@ def test_rotary_embedding( ...@@ -74,7 +74,7 @@ def test_rotary_embedding(
device: str, device: str,
use_key: bool, use_key: bool,
max_position: int = 8192, max_position: int = 8192,
base: float = 10000, rope_theta: float = 10000,
) -> None: ) -> None:
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
...@@ -83,7 +83,8 @@ def test_rotary_embedding( ...@@ -83,7 +83,8 @@ def test_rotary_embedding(
torch.set_default_device(device) torch.set_default_device(device)
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=torch.get_default_device()) rope = rope.to(dtype=dtype, device=torch.get_default_device())
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
...@@ -120,9 +121,9 @@ def test_rotary_embedding( ...@@ -120,9 +121,9 @@ def test_rotary_embedding(
@torch.inference_mode() @torch.inference_mode()
def test_rope_module_cache(): def test_rope_module_cache():
MAX_POSITIONS = [123, 1234] MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000] ROPE_THETAS = [10000, 1000000]
ROPE_SCALINGS = ( ROPE_PARAMETERS = (
None, {"rope_type": "default"},
{"rope_type": "linear", "factor": (1,)}, {"rope_type": "linear", "factor": (1,)},
{"rope_type": "dynamic", "factor": 1}, {"rope_type": "dynamic", "factor": 1},
) )
...@@ -130,9 +131,9 @@ def test_rope_module_cache(): ...@@ -130,9 +131,9 @@ def test_rope_module_cache():
HEAD_SIZES, HEAD_SIZES,
ROTARY_DIMS, ROTARY_DIMS,
MAX_POSITIONS, MAX_POSITIONS,
BASES, ROPE_THETAS,
IS_NEOX_STYLE, IS_NEOX_STYLE,
ROPE_SCALINGS, ROPE_PARAMETERS,
DTYPES, DTYPES,
) )
rope_setting_id_map: dict[str, int] = {} rope_setting_id_map: dict[str, int] = {}
...@@ -141,20 +142,20 @@ def test_rope_module_cache(): ...@@ -141,20 +142,20 @@ def test_rope_module_cache():
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
base, rope_theta,
is_neox_stype, is_neox_style,
rope_scaling, rope_parameters,
dtype, dtype,
) = setting ) = setting
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta
rope = get_rope( rope = get_rope(
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
base, is_neox_style,
is_neox_stype, rope_parameters,
rope_scaling,
dtype, dtype,
) )
# different settings cannot share the same rope module # different settings cannot share the same rope module
...@@ -168,20 +169,20 @@ def test_rope_module_cache(): ...@@ -168,20 +169,20 @@ def test_rope_module_cache():
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
base, rope_theta,
is_neox_stype, is_neox_style,
rope_scaling, rope_parameters,
dtype, dtype,
) = setting ) = setting
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta
rope = get_rope( rope = get_rope(
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
base, is_neox_style,
is_neox_stype, rope_parameters,
rope_scaling,
dtype, dtype,
) )
# check if cache take effect # check if cache take effect
......
...@@ -201,7 +201,7 @@ class ModelConfig: ...@@ -201,7 +201,7 @@ class ModelConfig:
sliding_window: int = 128 sliding_window: int = 128
initial_context_length: int = 4096 initial_context_length: int = 4096
rope_theta: float = 150000.0 rope_theta: float = 150000.0
rope_scaling_factor: float = 32.0 rope_parameters_factor: float = 32.0
rope_ntk_alpha: float = 1.0 rope_ntk_alpha: float = 1.0
rope_ntk_beta: float = 32.0 rope_ntk_beta: float = 32.0
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117 # ruff: noqa: SIM117
from typing import Any
import pytest import pytest
from ...utils import EmbedModelInfo from ...utils import EmbedModelInfo
...@@ -79,8 +81,8 @@ def test_set_max_model_len_illegal(model_info, vllm_runner): ...@@ -79,8 +81,8 @@ def test_set_max_model_len_illegal(model_info, vllm_runner):
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
def test_use_rope_scaling_legal(model_info, vllm_runner): def test_use_rope_scaling_legal(model_info, vllm_runner):
hf_overrides = { hf_overrides = {
"rope_parameters": {
"rope_theta": rope_theta, "rope_theta": rope_theta,
"rope_scaling": {
"rope_type": "yarn", "rope_type": "yarn",
"factor": factor, "factor": factor,
"original_max_position_embeddings": original_max_position_embeddings, "original_max_position_embeddings": original_max_position_embeddings,
...@@ -96,9 +98,9 @@ def test_use_rope_scaling_legal(model_info, vllm_runner): ...@@ -96,9 +98,9 @@ def test_use_rope_scaling_legal(model_info, vllm_runner):
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
def test_use_rope_scaling_illegal(model_info, vllm_runner): def test_use_rope_scaling_illegal(model_info, vllm_runner):
hf_overrides = { hf_overrides: dict[str, Any] = {
"rope_parameters": {
"rope_theta": rope_theta, "rope_theta": rope_theta,
"rope_scaling": {
"rope_type": "yarn", "rope_type": "yarn",
"factor": factor, "factor": factor,
"original_max_position_embeddings": original_max_position_embeddings, "original_max_position_embeddings": original_max_position_embeddings,
...@@ -115,8 +117,8 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner): ...@@ -115,8 +117,8 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner):
pass pass
hf_overrides = { hf_overrides = {
"rope_parameters": {
"rope_theta": rope_theta, "rope_theta": rope_theta,
"rope_scaling": {
"rope_type": "yarn", "rope_type": "yarn",
"factor": factor, "factor": factor,
"original_max_position_embeddings": original_max_position_embeddings, "original_max_position_embeddings": original_max_position_embeddings,
......
...@@ -249,45 +249,48 @@ def test_get_bert_tokenization_sentence_transformer_config(): ...@@ -249,45 +249,48 @@ def test_get_bert_tokenization_sentence_transformer_config():
def test_rope_customization(): def test_rope_customization():
TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0} TEST_ROPE_PARAMETERS = {
TEST_ROPE_THETA = 16_000_000.0 "rope_theta": 16_000_000.0,
LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0} "rope_type": "dynamic",
"factor": 2.0,
}
LLAMA_ROPE_PARAMETERS = {"rope_theta": 500000.0, "rope_type": "default"}
LONGCHAT_ROPE_PARAMETERS = {"rope_type": "linear", "factor": 8.0}
llama_model_config = ModelConfig("meta-llama/Meta-Llama-3-8B-Instruct") llama_model_config = ModelConfig("meta-llama/Meta-Llama-3-8B-Instruct")
assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None assert (
assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000 getattr(llama_model_config.hf_config, "rope_parameters", None)
== LLAMA_ROPE_PARAMETERS
)
assert llama_model_config.max_model_len == 8192 assert llama_model_config.max_model_len == 8192
llama_model_config = ModelConfig( llama_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct",
hf_overrides={ hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
"rope_scaling": TEST_ROPE_SCALING,
"rope_theta": TEST_ROPE_THETA,
},
) )
assert ( assert (
getattr(llama_model_config.hf_config, "rope_scaling", None) == TEST_ROPE_SCALING getattr(llama_model_config.hf_config, "rope_parameters", None)
== TEST_ROPE_PARAMETERS
) )
assert getattr(llama_model_config.hf_config, "rope_theta", None) == TEST_ROPE_THETA
assert llama_model_config.max_model_len == 16384 assert llama_model_config.max_model_len == 16384
longchat_model_config = ModelConfig("lmsys/longchat-13b-16k") longchat_model_config = ModelConfig("lmsys/longchat-13b-16k")
# Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config # Check if LONGCHAT_ROPE_PARAMETERS entries are in longchat_model_config
assert all( assert all(
longchat_model_config.hf_config.rope_scaling.get(key) == value longchat_model_config.hf_config.rope_parameters.get(key) == value
for key, value in LONGCHAT_ROPE_SCALING.items() for key, value in LONGCHAT_ROPE_PARAMETERS.items()
) )
assert longchat_model_config.max_model_len == 16384 assert longchat_model_config.max_model_len == 16384
longchat_model_config = ModelConfig( longchat_model_config = ModelConfig(
"lmsys/longchat-13b-16k", "lmsys/longchat-13b-16k",
hf_overrides={ hf_overrides={
"rope_scaling": TEST_ROPE_SCALING, "rope_parameters": TEST_ROPE_PARAMETERS,
}, },
) )
assert ( assert (
getattr(longchat_model_config.hf_config, "rope_scaling", None) getattr(longchat_model_config.hf_config, "rope_parameters", None)
== TEST_ROPE_SCALING == TEST_ROPE_PARAMETERS
) )
assert longchat_model_config.max_model_len == 4096 assert longchat_model_config.max_model_len == 4096
......
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
from pydantic import ConfigDict, SkipValidation, field_validator, model_validator from pydantic import ConfigDict, SkipValidation, field_validator, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
import vllm.envs as envs import vllm.envs as envs
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
...@@ -2100,30 +2101,31 @@ def _get_and_verify_max_len( ...@@ -2100,30 +2101,31 @@ def _get_and_verify_max_len(
) )
derived_max_model_len = default_max_len derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None) # In Transformers v5 rope_parameters could be TypedDict or dict[str, TypedDict].
# To simplify the verification, we convert it to dict[str, TypedDict].
rope_parameters = getattr(hf_config, "rope_parameters", None)
if rope_parameters and not set(rope_parameters.keys()).issubset(
ALLOWED_LAYER_TYPES
):
rope_parameters = {"": rope_parameters}
# NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
# scaling, so we skip applying the scaling factor again. # scaling, so we skip applying the scaling factor again.
if rope_scaling is not None and "gemma3" not in hf_config.model_type: if rope_parameters is not None and "gemma3" not in hf_config.model_type:
# No need to consider "type" key because of patch_rope_scaling when scaling_factor = 1.0
for rp in rope_parameters.values():
# No need to consider "type" key because of patch_rope_parameters when
# loading HF config # loading HF config
rope_type = rope_scaling["rope_type"] rope_type = rp["rope_type"]
if rope_type not in ("su", "longrope", "llama3"): if rope_type not in ("su", "longrope", "llama3"):
if disable_sliding_window: # NOTE: rope_type == "default" does not define factor https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
# TODO(robertgshaw): Find a model that supports rope_scaling # NOTE: This assumes all layer types have the same scaling factor.
# with sliding window to see if this case should be allowed. scaling_factor = rp.get("factor", scaling_factor)
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate."
)
# NOTE: rope_type == "default" does not define factor
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
scaling_factor = rope_scaling.get("factor", 1.0)
if rope_type == "yarn": if rope_type == "yarn":
derived_max_model_len = rope_scaling["original_max_position_embeddings"] derived_max_model_len = rp["original_max_position_embeddings"]
# Do this outside loop since all layer types should have the same scaling
derived_max_model_len *= scaling_factor derived_max_model_len *= scaling_factor
if encoder_config and "max_seq_length" in encoder_config: if encoder_config and "max_seq_length" in encoder_config:
...@@ -2134,7 +2136,9 @@ def _get_and_verify_max_len( ...@@ -2134,7 +2136,9 @@ def _get_and_verify_max_len(
if max_model_len is None: if max_model_len is None:
# For LongRoPE, default to original_max_position_embeddings to avoid # For LongRoPE, default to original_max_position_embeddings to avoid
# performance degradation for shorter sequences # performance degradation for shorter sequences
if rope_scaling is not None and rope_scaling["rope_type"] == "longrope": if rope_parameters is not None and any(
rp["rope_type"] == "longrope" for rp in rope_parameters.values()
):
max_model_len = int( max_model_len = int(
getattr( getattr(
hf_config, "original_max_position_embeddings", derived_max_model_len hf_config, "original_max_position_embeddings", derived_max_model_len
...@@ -2151,16 +2155,7 @@ def _get_and_verify_max_len( ...@@ -2151,16 +2155,7 @@ def _get_and_verify_max_len(
# that will be bigger than derived_max_model_len. We compare user input # that will be bigger than derived_max_model_len. We compare user input
# with model_max_length and allow this override when it's smaller. # with model_max_length and allow this override when it's smaller.
model_max_length = getattr(hf_config, "model_max_length", None) model_max_length = getattr(hf_config, "model_max_length", None)
if model_max_length is not None and max_model_len <= model_max_length: if model_max_length is None or max_model_len > model_max_length:
if disable_sliding_window:
# TODO(robertgshaw): Find a model that has model_max_length
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue "
"so we can investigate."
)
else:
msg = ( msg = (
f"User-specified max_model_len ({max_model_len}) is greater " f"User-specified max_model_len ({max_model_len}) is greater "
f"than the derived max_model_len ({max_len_key}=" f"than the derived max_model_len ({max_len_key}="
......
...@@ -26,23 +26,23 @@ def get_rope( ...@@ -26,23 +26,23 @@ def get_rope(
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position: int, max_position: int,
base: float,
is_neox_style: bool = True, is_neox_style: bool = True,
rope_scaling: dict[str, Any] | None = None, rope_parameters: dict[str, Any] | None = None,
dtype: torch.dtype | None = None, dtype: torch.dtype | None = None,
partial_rotary_factor: float = 1.0, partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: dict[str, Any] | None = None, dual_chunk_attention_config: dict[str, Any] | None = None,
) -> RotaryEmbedding: ) -> RotaryEmbedding:
if dtype is None: if dtype is None:
dtype = torch.get_default_dtype() dtype = torch.get_default_dtype()
if rope_scaling is not None: if rope_parameters is not None:
# Transforms every value that is a list into a tuple for caching calls # Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = { rope_parameters_tuple = {
k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() k: tuple(v) if isinstance(v, list) else v
for k, v in rope_parameters.items()
} }
rope_scaling_args = tuple(rope_scaling_tuple.items()) rope_parameters_args = tuple(rope_parameters_tuple.items())
else: else:
rope_scaling_args = None rope_parameters_args = None
if dual_chunk_attention_config is not None: if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = { dual_chunk_attention_tuple = {
...@@ -60,15 +60,15 @@ def get_rope( ...@@ -60,15 +60,15 @@ def get_rope(
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
base,
is_neox_style, is_neox_style,
rope_scaling_args, rope_parameters_args,
dual_chunk_attention_args, dual_chunk_attention_args,
dtype, dtype,
) )
if key in _ROPE_DICT: if key in _ROPE_DICT:
return _ROPE_DICT[key] return _ROPE_DICT[key]
base = rope_parameters["rope_theta"] if rope_parameters else 10000
if dual_chunk_attention_config is not None: if dual_chunk_attention_config is not None:
extra_kwargs = { extra_kwargs = {
k: v k: v
...@@ -84,18 +84,18 @@ def get_rope( ...@@ -84,18 +84,18 @@ def get_rope(
dtype, dtype,
**extra_kwargs, **extra_kwargs,
) )
elif not rope_scaling: elif not rope_parameters:
rotary_emb = RotaryEmbedding( rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype head_size, rotary_dim, max_position, base, is_neox_style, dtype
) )
else: else:
scaling_type = rope_scaling["rope_type"] scaling_type = rope_parameters["rope_type"]
if scaling_type == "llama3": if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_scaling["low_freq_factor"] low_freq_factor = rope_parameters["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"] high_freq_factor = rope_parameters["high_freq_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"] original_max_position = rope_parameters["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding( rotary_emb = Llama3RotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
...@@ -113,7 +113,7 @@ def get_rope( ...@@ -113,7 +113,7 @@ def get_rope(
head_size, rotary_dim, max_position, base, is_neox_style, dtype head_size, rotary_dim, max_position, base, is_neox_style, dtype
) )
elif scaling_type == "default": elif scaling_type == "default":
if "mrope_section" in rope_scaling: if "mrope_section" in rope_parameters:
rotary_emb = MRotaryEmbedding( rotary_emb = MRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
...@@ -121,8 +121,8 @@ def get_rope( ...@@ -121,8 +121,8 @@ def get_rope(
base, base,
is_neox_style, is_neox_style,
dtype, dtype,
mrope_section=rope_scaling["mrope_section"], mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False), mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
) )
else: else:
rotary_emb = RotaryEmbedding( rotary_emb = RotaryEmbedding(
...@@ -134,7 +134,7 @@ def get_rope( ...@@ -134,7 +134,7 @@ def get_rope(
dtype, dtype,
) )
elif scaling_type == "linear": elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding( rotary_emb = LinearScalingRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
...@@ -145,8 +145,8 @@ def get_rope( ...@@ -145,8 +145,8 @@ def get_rope(
dtype, dtype,
) )
elif scaling_type == "ntk": elif scaling_type == "ntk":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_parameters["factor"]
mixed_b = rope_scaling.get("mixed_b", None) mixed_b = rope_parameters.get("mixed_b")
rotary_emb = NTKScalingRotaryEmbedding( rotary_emb = NTKScalingRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
...@@ -158,8 +158,8 @@ def get_rope( ...@@ -158,8 +158,8 @@ def get_rope(
mixed_b, mixed_b,
) )
elif scaling_type == "dynamic": elif scaling_type == "dynamic":
if "alpha" in rope_scaling: if "alpha" in rope_parameters:
scaling_alpha = rope_scaling["alpha"] scaling_alpha = rope_parameters["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding( rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
...@@ -169,8 +169,8 @@ def get_rope( ...@@ -169,8 +169,8 @@ def get_rope(
scaling_alpha, scaling_alpha,
dtype, dtype,
) )
elif "factor" in rope_scaling: elif "factor" in rope_parameters:
scaling_factor = rope_scaling["factor"] scaling_factor = rope_parameters["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding( rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
...@@ -185,11 +185,11 @@ def get_rope( ...@@ -185,11 +185,11 @@ def get_rope(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field" "Dynamic rope scaling must contain either 'alpha' or 'factor' field"
) )
elif scaling_type == "yarn": elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_parameters["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"] original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = { extra_kwargs = {
k: v k: v
for k, v in rope_scaling.items() for k, v in rope_parameters.items()
if k if k
in ( in (
"extrapolation_factor", "extrapolation_factor",
...@@ -199,7 +199,7 @@ def get_rope( ...@@ -199,7 +199,7 @@ def get_rope(
"apply_yarn_scaling", "apply_yarn_scaling",
) )
} }
if "mrope_section" in rope_scaling: if "mrope_section" in rope_parameters:
extra_kwargs.pop("apply_yarn_scaling", None) extra_kwargs.pop("apply_yarn_scaling", None)
rotary_emb = MRotaryEmbedding( rotary_emb = MRotaryEmbedding(
head_size, head_size,
...@@ -208,8 +208,8 @@ def get_rope( ...@@ -208,8 +208,8 @@ def get_rope(
base, base,
is_neox_style, is_neox_style,
dtype, dtype,
mrope_section=rope_scaling["mrope_section"], mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False), mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
**extra_kwargs, **extra_kwargs,
) )
...@@ -225,12 +225,12 @@ def get_rope( ...@@ -225,12 +225,12 @@ def get_rope(
**extra_kwargs, **extra_kwargs,
) )
elif scaling_type == "deepseek_yarn": elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_parameters["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"] original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor # assert max_position == original_max_position * scaling_factor
extra_kwargs = { extra_kwargs = {
k: v k: v
for k, v in rope_scaling.items() for k, v in rope_parameters.items()
if k if k
in ( in (
"extrapolation_factor", "extrapolation_factor",
...@@ -252,12 +252,12 @@ def get_rope( ...@@ -252,12 +252,12 @@ def get_rope(
**extra_kwargs, **extra_kwargs,
) )
elif scaling_type == "longrope": elif scaling_type == "longrope":
short_factor = rope_scaling["short_factor"] short_factor = rope_parameters["short_factor"]
long_factor = rope_scaling["long_factor"] long_factor = rope_parameters["long_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"] original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = { extra_kwargs = {
k: v k: v
for k, v in rope_scaling.items() for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale") if k in ("short_mscale", "long_mscale")
} }
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import islice from itertools import islice
from typing import Any
import torch import torch
from torch import nn from torch import nn
...@@ -171,8 +170,6 @@ class AfmoeAttention(nn.Module): ...@@ -171,8 +170,6 @@ class AfmoeAttention(nn.Module):
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 131072, max_position_embeddings: int = 131072,
head_dim: int | None = None, head_dim: int | None = None,
rms_norm_eps: float = 1e-05, rms_norm_eps: float = 1e-05,
...@@ -202,7 +199,6 @@ class AfmoeAttention(nn.Module): ...@@ -202,7 +199,6 @@ class AfmoeAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
# Check if this is a local attention layer # Check if this is a local attention layer
...@@ -246,8 +242,7 @@ class AfmoeAttention(nn.Module): ...@@ -246,8 +242,7 @@ class AfmoeAttention(nn.Module):
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, rope_parameters=config["rope_parameters"],
rope_scaling=rope_scaling,
is_neox_style=True, is_neox_style=True,
) )
else: else:
...@@ -303,14 +298,6 @@ class AfmoeDecoderLayer(nn.Module): ...@@ -303,14 +298,6 @@ class AfmoeDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 131072) max_position_embeddings = getattr(config, "max_position_embeddings", 131072)
# DecoderLayers are created with `make_layers` which passes the prefix # DecoderLayers are created with `make_layers` which passes the prefix
...@@ -323,8 +310,6 @@ class AfmoeDecoderLayer(nn.Module): ...@@ -323,8 +310,6 @@ class AfmoeDecoderLayer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
head_dim=config.head_dim, head_dim=config.head_dim,
rms_norm_eps=config.rms_norm_eps, rms_norm_eps=config.rms_norm_eps,
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice from itertools import islice
from typing import Any
import torch import torch
from torch import nn from torch import nn
...@@ -118,8 +117,6 @@ class ApertusAttention(nn.Module): ...@@ -118,8 +117,6 @@ class ApertusAttention(nn.Module):
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
bias: bool = False, bias: bool = False,
...@@ -155,7 +152,6 @@ class ApertusAttention(nn.Module): ...@@ -155,7 +152,6 @@ class ApertusAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
...@@ -176,9 +172,7 @@ class ApertusAttention(nn.Module): ...@@ -176,9 +172,7 @@ class ApertusAttention(nn.Module):
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
) )
self._init_rotary_emb( self._init_rotary_emb(config, quant_config=quant_config)
config, rope_scaling=rope_scaling, quant_config=quant_config
)
sliding_window = None sliding_window = None
if layer_types := getattr(config, "layer_types", None): if layer_types := getattr(config, "layer_types", None):
...@@ -224,7 +218,6 @@ class ApertusAttention(nn.Module): ...@@ -224,7 +218,6 @@ class ApertusAttention(nn.Module):
def _init_rotary_emb( def _init_rotary_emb(
self, self,
config: ApertusConfig, config: ApertusConfig,
rope_scaling: dict[str, Any] | None,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
) -> None: ) -> None:
is_neox_style = True is_neox_style = True
...@@ -236,8 +229,7 @@ class ApertusAttention(nn.Module): ...@@ -236,8 +229,7 @@ class ApertusAttention(nn.Module):
self.head_dim, self.head_dim,
rotary_dim=int(self.partial_rotary_factor * self.head_dim), rotary_dim=int(self.partial_rotary_factor * self.head_dim),
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
base=self.rope_theta, rope_parameters=config.rope_parameters,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor, partial_rotary_factor=self.partial_rotary_factor,
) )
...@@ -253,14 +245,6 @@ class ApertusDecoderLayer(nn.Module): ...@@ -253,14 +245,6 @@ class ApertusDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias # Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias # Support internlm/internlm-7b with bias
...@@ -288,8 +272,6 @@ class ApertusDecoderLayer(nn.Module): ...@@ -288,8 +272,6 @@ class ApertusDecoderLayer(nn.Module):
num_kv_heads=getattr( num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads config, "num_key_value_heads", config.num_attention_heads
), ),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
quant_config=quant_config, quant_config=quant_config,
bias=attention_bias, bias=attention_bias,
......
...@@ -103,15 +103,6 @@ class ArceeDecoderLayer(nn.Module): ...@@ -103,15 +103,6 @@ class ArceeDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Rotary embedding parameters (reuse LLaMA defaults)
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Determine if attention bias is needed (some variants use bias terms) # Determine if attention bias is needed (some variants use bias terms)
attention_bias = getattr(config, "attention_bias", False) or getattr( attention_bias = getattr(config, "attention_bias", False) or getattr(
...@@ -133,8 +124,6 @@ class ArceeDecoderLayer(nn.Module): ...@@ -133,8 +124,6 @@ class ArceeDecoderLayer(nn.Module):
num_kv_heads=getattr( num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads config, "num_key_value_heads", config.num_attention_heads
), ),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
quant_config=quant_config, quant_config=quant_config,
bias=attention_bias, bias=attention_bias,
......
...@@ -292,7 +292,6 @@ class ArcticAttention(nn.Module): ...@@ -292,7 +292,6 @@ class ArcticAttention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
...@@ -317,7 +316,7 @@ class ArcticAttention(nn.Module): ...@@ -317,7 +316,7 @@ class ArcticAttention(nn.Module):
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
base=int(self.rope_theta), rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
) )
......
...@@ -136,7 +136,7 @@ class BaiChuanAttention(nn.Module): ...@@ -136,7 +136,7 @@ class BaiChuanAttention(nn.Module):
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
position_embedding: str, position_embedding: str,
rope_theta: float = 10000, rope_parameters: dict,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None, cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
...@@ -150,7 +150,6 @@ class BaiChuanAttention(nn.Module): ...@@ -150,7 +150,6 @@ class BaiChuanAttention(nn.Module):
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.position_embedding = position_embedding self.position_embedding = position_embedding
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
# pylint: disable=invalid-name # pylint: disable=invalid-name
...@@ -192,7 +191,7 @@ class BaiChuanAttention(nn.Module): ...@@ -192,7 +191,7 @@ class BaiChuanAttention(nn.Module):
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
base=self.rope_theta, rope_parameters=rope_parameters,
) )
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = Attention( self.attn = Attention(
...@@ -229,13 +228,12 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -229,13 +228,12 @@ class BaiChuanDecoderLayer(nn.Module):
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = BaiChuanAttention( self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
position_embedding=position_embedding, position_embedding=position_embedding,
rope_theta=rope_theta, rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
......
...@@ -135,9 +135,8 @@ class BailingAttention(nn.Module): ...@@ -135,9 +135,8 @@ class BailingAttention(nn.Module):
self.head_dim, self.head_dim,
rotary_dim=self.rotary_dim, rotary_dim=self.rotary_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
base=config.rope_theta, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
rope_scaling=config.rope_scaling,
partial_rotary_factor=self.partial_rotary_factor, partial_rotary_factor=self.partial_rotary_factor,
) )
......
...@@ -156,8 +156,6 @@ class BambaAttentionDecoderLayer(nn.Module): ...@@ -156,8 +156,6 @@ class BambaAttentionDecoderLayer(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -178,7 +176,6 @@ class BambaAttentionDecoderLayer(nn.Module): ...@@ -178,7 +176,6 @@ class BambaAttentionDecoderLayer(nn.Module):
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
if hasattr(config, "partial_rotary_factor"): if hasattr(config, "partial_rotary_factor"):
...@@ -192,8 +189,7 @@ class BambaAttentionDecoderLayer(nn.Module): ...@@ -192,8 +189,7 @@ class BambaAttentionDecoderLayer(nn.Module):
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=rotary_dim, rotary_dim=rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_scaling=rope_scaling, rope_parameters=config.rope_parameters,
base=rope_theta,
is_neox_style=True, is_neox_style=True,
dtype=torch.get_default_dtype(), # see impl of get_rope dtype=torch.get_default_dtype(), # see impl of get_rope
) )
......
...@@ -265,8 +265,7 @@ class ChameleonAttention(nn.Module): ...@@ -265,8 +265,7 @@ class ChameleonAttention(nn.Module):
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_parameters: dict[str, Any],
rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 4096, max_position_embeddings: int = 4096,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
bias: bool = False, bias: bool = False,
...@@ -293,7 +292,6 @@ class ChameleonAttention(nn.Module): ...@@ -293,7 +292,6 @@ class ChameleonAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
...@@ -318,8 +316,7 @@ class ChameleonAttention(nn.Module): ...@@ -318,8 +316,7 @@ class ChameleonAttention(nn.Module):
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, rope_parameters=rope_parameters,
rope_scaling=rope_scaling,
) )
self.attn = Attention( self.attn = Attention(
...@@ -369,14 +366,6 @@ class ChameleonDecoderLayer(nn.Module): ...@@ -369,14 +366,6 @@ class ChameleonDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 4096) max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
self.self_attn = ChameleonAttention( self.self_attn = ChameleonAttention(
...@@ -385,8 +374,7 @@ class ChameleonDecoderLayer(nn.Module): ...@@ -385,8 +374,7 @@ class ChameleonDecoderLayer(nn.Module):
num_kv_heads=getattr( num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads config, "num_key_value_heads", config.num_attention_heads
), ),
rope_theta=rope_theta, rope_parameters=config.rope_parameters,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
...@@ -439,14 +427,6 @@ class ChameleonSwinDecoderLayer(nn.Module): ...@@ -439,14 +427,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 4096) max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
self.self_attn = ChameleonAttention( self.self_attn = ChameleonAttention(
...@@ -455,8 +435,7 @@ class ChameleonSwinDecoderLayer(nn.Module): ...@@ -455,8 +435,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
num_kv_heads=getattr( num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads config, "num_key_value_heads", config.num_attention_heads
), ),
rope_theta=rope_theta, rope_parameters=config.rope_parameters,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
......
...@@ -99,6 +99,7 @@ class GLMAttention(nn.Module): ...@@ -99,6 +99,7 @@ class GLMAttention(nn.Module):
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 # https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0) rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192) max_positions = getattr(config, "seq_length", 8192)
rope_parameters = {"rope_type": "default", "rope_theta": 10000 * rope_ratio}
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False, # NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
# which is equivalent to is_neox_style=True # which is equivalent to is_neox_style=True
is_neox_style = not config.original_rope is_neox_style = not config.original_rope
...@@ -106,7 +107,7 @@ class GLMAttention(nn.Module): ...@@ -106,7 +107,7 @@ class GLMAttention(nn.Module):
self.head_dim, self.head_dim,
rotary_dim=self.head_dim // 2, rotary_dim=self.head_dim // 2,
max_position=max_positions, max_position=max_positions,
base=10000 * rope_ratio, rope_parameters=rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
) )
self.attn = Attention( self.attn = Attention(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment