Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a8375f8
Unverified
Commit
7a8375f8
authored
Nov 06, 2025
by
Julien Denize
Committed by
GitHub
Nov 06, 2025
Browse files
Add llama 4 scaling support (#28145)
Signed-off-by:
Julien Denize
<
julien.denize@mistral.ai
>
parent
5e0c1fe6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
8 deletions
+59
-8
vllm/model_executor/layers/rotary_embedding/__init__.py
vllm/model_executor/layers/rotary_embedding/__init__.py
+8
-1
vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py
...del_executor/layers/rotary_embedding/yarn_scaling_rope.py
+6
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+22
-0
vllm/transformers_utils/configs/mistral.py
vllm/transformers_utils/configs/mistral.py
+23
-6
No files found.
vllm/model_executor/layers/rotary_embedding/__init__.py
View file @
7a8375f8
...
...
@@ -191,9 +191,16 @@ def get_rope(
k
:
v
for
k
,
v
in
rope_scaling
.
items
()
if
k
in
(
"extrapolation_factor"
,
"attn_factor"
,
"beta_fast"
,
"beta_slow"
)
in
(
"extrapolation_factor"
,
"attn_factor"
,
"beta_fast"
,
"beta_slow"
,
"apply_yarn_scaling"
,
)
}
if
"mrope_section"
in
rope_scaling
:
extra_kwargs
.
pop
(
"apply_yarn_scaling"
,
None
)
rotary_emb
=
MRotaryEmbedding
(
head_size
,
rotary_dim
,
...
...
vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py
View file @
7a8375f8
...
...
@@ -27,6 +27,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
attn_factor
:
float
=
1
,
beta_fast
:
int
=
32
,
beta_slow
:
int
=
1
,
apply_yarn_scaling
:
bool
=
True
,
)
->
None
:
self
.
scaling_factor
=
scaling_factor
self
.
extrapolation_factor
=
extrapolation_factor
...
...
@@ -34,7 +35,11 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self
.
beta_fast
=
beta_fast
self
.
beta_slow
=
beta_slow
# Get n-d magnitude scaling corrected for interpolation
self
.
mscale
=
float
(
yarn_get_mscale
(
self
.
scaling_factor
)
*
attn_factor
)
self
.
mscale
=
(
float
(
yarn_get_mscale
(
self
.
scaling_factor
)
*
attn_factor
)
if
apply_yarn_scaling
else
float
(
attn_factor
)
)
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
...
...
vllm/model_executor/models/llama.py
View file @
7a8375f8
...
...
@@ -160,6 +160,14 @@ class LlamaAttention(nn.Module):
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
llama_4_scaling_config
=
getattr
(
config
,
"llama_4_scaling"
,
None
)
self
.
do_llama_4_scaling
=
llama_4_scaling_config
is
not
None
if
self
.
do_llama_4_scaling
:
self
.
llama_4_scaling_original_max_position_embeddings
=
(
llama_4_scaling_config
[
"original_max_position_embeddings"
]
)
self
.
llama_4_scaling_beta
=
llama_4_scaling_config
[
"beta"
]
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
=
hidden_size
,
head_size
=
self
.
head_dim
,
...
...
@@ -221,6 +229,17 @@ class LlamaAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.attn"
,
)
def
_get_llama_4_attn_scale
(
self
,
positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Llama4 scaling
scaling
=
1
+
self
.
llama_4_scaling_beta
*
torch
.
log
(
1
+
torch
.
floor
(
positions
/
self
.
llama_4_scaling_original_max_position_embeddings
)
)
# Broadcast over head_dim
return
scaling
.
unsqueeze
(
-
1
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
@@ -229,6 +248,9 @@ class LlamaAttention(nn.Module):
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
if
self
.
do_llama_4_scaling
:
attn_scale
=
self
.
_get_llama_4_attn_scale
(
positions
)
q
=
(
q
*
attn_scale
).
to
(
q
.
dtype
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
vllm/transformers_utils/configs/mistral.py
View file @
7a8375f8
...
...
@@ -24,6 +24,18 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
if
bool
(
config_dict
.
get
(
"yarn"
)):
config_dict
=
_remap_mistral_yarn_args
(
config_dict
)
if
bool
(
config_dict
.
get
(
"llama_4_scaling"
)):
llama_4_scaling_config_keys
=
[
"original_max_position_embeddings"
,
"beta"
]
assert
all
(
[
key
in
config_dict
[
"llama_4_scaling"
]
for
key
in
llama_4_scaling_config_keys
]
),
(
"llama_4_scaling config should define the keys: "
f
"
{
','
.
join
(
llama_4_scaling_config_keys
)
}
"
)
is_vision
=
(
config_dict
.
get
(
"multimodal"
)
or
{}).
get
(
"vision_encoder_args"
)
or
config_dict
.
get
(
"vision_encoder"
)
...
...
@@ -66,19 +78,24 @@ def _remap_mistral_vision_args(config: dict) -> dict:
def
_remap_mistral_yarn_args
(
config
:
dict
)
->
dict
:
# Direct remaps: yarn.X -> rope_scaling.Y
# Source keys are from mistral.model.args.YarnArgs
_map
=
{
yarn_config_map
=
{
"factor"
:
"factor"
,
"original_max_position_embeddings"
:
"original_max_position_embeddings"
,
"beta"
:
"beta_fast"
,
"alpha"
:
"beta_slow"
,
"apply_scale"
:
"apply_yarn_scaling"
,
}
yarn_config
=
config
.
get
(
"yarn"
)
or
{}
renamed_yarn_config
=
{
_map
.
get
(
k
,
k
):
v
for
k
,
v
in
yarn_config
.
items
()}
config
[
"rope_scaling"
]
=
{
"rope_type"
:
"yarn"
,
"mscale_all_dim"
:
1
,
# We hardcoded this to 1
**
renamed_yarn_config
,
"mscale_all_dim"
:
1
,
}
for
old_name
,
new_name
in
yarn_config_map
.
items
():
if
old_name
in
yarn_config
:
config
[
"rope_scaling"
][
new_name
]
=
yarn_config
.
pop
(
old_name
)
assert
len
(
yarn_config
)
==
0
,
f
"Unparsed yarn config:
{
yarn_config
}
"
return
config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment