Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a810671a
Commit
a810671a
authored
Jan 08, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori
parents
86b5aefe
6a09612b
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
362 additions
and
87 deletions
+362
-87
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+1
-1
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+29
-14
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+5
-5
vllm/model_executor/layers/quantization/quark/quark.py
vllm/model_executor/layers/quantization/quark/quark.py
+43
-0
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+158
-2
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+30
-30
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+5
-2
vllm/model_executor/layers/rotary_embedding/base.py
vllm/model_executor/layers/rotary_embedding/base.py
+4
-1
vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
...executor/layers/rotary_embedding/deepseek_scaling_rope.py
+20
-1
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+29
-6
vllm/model_executor/models/aimv2.py
vllm/model_executor/models/aimv2.py
+2
-2
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+3
-1
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+2
-2
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+6
-5
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+0
-6
vllm/model_executor/models/deepencoder.py
vllm/model_executor/models/deepencoder.py
+2
-2
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+1
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+5
-1
vllm/model_executor/models/gemma3.py
vllm/model_executor/models/gemma3.py
+15
-4
vllm/model_executor/models/glm4v.py
vllm/model_executor/models/glm4v.py
+2
-2
No files found.
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
a810671a
...
...
@@ -181,7 +181,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
return
75
@
classmethod
def
get_config_filenames
(
cls
)
->
list
[
str
]:
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
a810671a
...
...
@@ -871,7 +871,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
return
75
@
classmethod
def
override_quantization_method
(
...
...
@@ -1458,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
logger
.
debug_once
(
"Finished shuffling weights for TRT-LLM MOE"
)
layer
.
gemm1
_weight
s_fp4_shuffled
=
Parameter
(
layer
.
w13
_weight
=
Parameter
(
gemm1_weights_fp4_shuffled
,
requires_grad
=
False
)
layer
.
gemm2_weights_fp4_shuffled
=
Parameter
(
gemm2_weights_fp4_shuffled
,
requires_grad
=
False
)
layer
.
gemm1_scales_fp4_shuffled
=
Parameter
(
layer
.
w2_weight
=
Parameter
(
gemm2_weights_fp4_shuffled
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
Parameter
(
gemm1_scales_fp4_shuffled
,
requires_grad
=
False
)
layer
.
gemm2_scales_fp4_shuff
le
d
=
Parameter
(
layer
.
w2_weight_sca
le
=
Parameter
(
gemm2_scales_fp4_shuffled
,
requires_grad
=
False
)
...
...
@@ -1476,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
(
layer
.
w2_input_scale_quant
*
layer
.
g1_alphas
).
to
(
torch
.
float32
),
requires_grad
=
False
,
)
# Clean up weights that won't be used by TRT-LLM
del
layer
.
w2_weight
del
layer
.
w2_weight_scale
del
layer
.
w13_weight
del
layer
.
w13_weight_scale
elif
self
.
use_marlin
:
# Marlin processing
prepare_moe_fp4_layer_for_marlin
(
layer
)
...
...
@@ -1530,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w2_blockscale_swizzled
,
requires_grad
=
False
)
def prepare_dp_allgather_tensor(
    self,
    layer: FusedMoE,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Quantize activations to FP4 so they can be carried through DP allgather/EP.

    Args:
        layer: MoE layer; its ``w13_input_scale_quant`` is passed to the
            quantizer as the global scale.
        hidden_states: Activations to quantize.
        router_logits: Not used by this implementation.

    Returns:
        A pair ``(hidden_states_fp4, extra_tensors)`` where ``extra_tensors``
        is a one-element list holding the scale factors produced by
        ``flashinfer.fp4_quantize`` (requested in the non-swizzled layout).
    """
    import flashinfer

    global_scale = layer.w13_input_scale_quant
    quantized, scale_factors = flashinfer.fp4_quantize(
        hidden_states,
        global_scale,
        is_sf_swizzled_layout=False,
    )
    # The scale factors travel alongside the quantized activations.
    return quantized, [scale_factors]
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
...
...
@@ -1584,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias
=
layer
.
e_score_correction_bias
,
)
# Hidden_states in select_experts is only used to extract metadata
if
isinstance
(
x
,
tuple
):
x_routing
,
_
=
x
else
:
x_routing
=
x
topk_weights
,
topk_ids
,
_
=
layer
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
_routing
,
router_logits
=
router_logits
,
)
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
a810671a
...
...
@@ -95,12 +95,12 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and
(
9
,
0
)
<=
current_platform
.
get_device_capability
()
<
(
11
,
0
)
)
if
envs
.
VLLM_MXFP4_USE_MARLIN
or
not
triton_kernels_supported
:
logger
.
info_once
(
"[get_mxfp4_backend_with_lora] Using
Marli
n backend"
)
return
Mxfp4Backend
.
MARLI
N
if
envs
.
VLLM_MXFP4_USE_MARLIN
is
False
and
triton_kernels_supported
:
logger
.
info_once
(
"[get_mxfp4_backend_with_lora] Using
Trito
n backend"
)
return
Mxfp4Backend
.
TRITO
N
logger
.
info_once
(
"[get_mxfp4_backend_with_lora] Using
Trito
n backend"
)
return
Mxfp4Backend
.
TRITO
N
logger
.
info_once
(
"[get_mxfp4_backend_with_lora] Using
Marli
n backend"
)
return
Mxfp4Backend
.
MARLI
N
def
get_mxfp4_backend
(
with_lora_support
:
bool
)
->
Mxfp4Backend
:
...
...
vllm/model_executor/layers/quantization/quark/quark.py
View file @
a810671a
...
...
@@ -218,6 +218,49 @@ class QuarkConfig(QuantizationConfig):
else
:
return
False
def
_is_fp8_w4a8
(
self
,
weight_quant
:
list
[
dict
[
str
,
Any
]]
|
None
,
input_quant
:
dict
[
str
,
Any
]
|
None
,
)
->
bool
:
# Confirm weights and input quantized.
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
if
not
isinstance
(
weight_quant
,
list
)
or
len
(
weight_quant
)
!=
2
:
return
False
# Confirm weight scheme is supported
is_w4a8_dtype
=
(
weight_quant
[
0
].
get
(
"dtype"
)
==
"fp8_e4m3"
and
weight_quant
[
1
].
get
(
"dtype"
)
==
"int4"
and
input_quant
.
get
(
"dtype"
)
==
"fp8_e4m3"
)
is_static_weight
=
not
weight_quant
[
0
].
get
(
"is_dynamic"
)
and
not
weight_quant
[
1
].
get
(
"is_dynamic"
)
is_per_tensor_fp8_and_per_channel_int4_weight
=
(
weight_quant
[
0
].
get
(
"qscheme"
)
==
"per_tensor"
and
weight_quant
[
1
].
get
(
"qscheme"
)
==
"per_channel"
and
weight_quant
[
1
].
get
(
"symmetric"
)
is
True
and
weight_quant
[
1
].
get
(
"ch_axis"
)
==
0
)
if
not
(
is_w4a8_dtype
and
is_static_weight
and
is_per_tensor_fp8_and_per_channel_int4_weight
):
return
False
# Dynamic quantization is always supported if weights supported.
if
input_quant
.
get
(
"is_dynamic"
):
return
True
# Confirm activation scheme is supported.
is_per_tensor_activation
=
input_quant
.
get
(
"qscheme"
)
==
"per_tensor"
return
is_per_tensor_activation
def
_is_fp8_w8a8
(
self
,
weight_quant
:
dict
[
str
,
Any
]
|
None
,
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
a810671a
...
...
@@ -63,8 +63,9 @@ class QuarkMoEMethod(FusedMoEMethodBase):
)
weight_config
=
layer_quant_config
.
get
(
"weight"
)
input_config
=
layer_quant_config
.
get
(
"input_tensors"
)
if
quant_config
.
_is_fp8_w8a8
(
weight_config
,
input_config
):
if
quant_config
.
_is_fp8_w4a8
(
weight_config
,
input_config
):
return
QuarkW4A8Fp8MoEMethod
(
weight_config
,
input_config
,
module
.
moe_config
)
elif
quant_config
.
_is_fp8_w8a8
(
weight_config
,
input_config
):
return
QuarkW8A8Fp8MoEMethod
(
weight_config
,
input_config
,
module
.
moe_config
)
elif
quant_config
.
_is_ocp_mx
(
weight_config
,
input_config
):
return
QuarkOCP_MX_MoEMethod
(
weight_config
,
input_config
,
module
.
moe_config
)
...
...
@@ -396,6 +397,161 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)
class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
    """Quark MoE method for INT4 weights with FP8 scales and FP8 activations.

    Expert weights are stored packed (eight 4-bit values per 32-bit word).
    Each expert carries a per-tensor FP8 scale (``*_weight_scale``) and a
    per-channel INT4 scale (``*_weight_scale_2``). Execution goes through
    the ROCm AITER fused-experts path.
    """

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        """Store the quantization configs and check AITER fused MoE is enabled."""
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config

        assert rocm_aiter_ops.is_fused_moe_enabled(), (
            "W4A8 FP8 MoE requires ROCm AITER fused MoE support."
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed expert weights and their two levels of scales.

        The incoming ``params_dtype`` is ignored: packed weights are always
        allocated as ``torch.uint32``.
        """
        packed_dtype = torch.uint32

        def _add_param(name: str, data: torch.Tensor) -> None:
            # Wrap, register on the layer, then tag with the *current*
            # contents of ``extra_weight_attrs``.
            param = torch.nn.Parameter(data, requires_grad=False)
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Packed INT4 weights: the last dimension is divided by 8 because
        # eight 4-bit values share one 32-bit word.
        _add_param(
            "w13_weight",
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // 8,
                dtype=packed_dtype,
            ),
        )
        _add_param(
            "w2_weight",
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 8,
                dtype=packed_dtype,
            ),
        )

        # Per-tensor fp8 weight scales (two entries per expert for w13).
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        _add_param(
            "w13_weight_scale",
            torch.ones(num_experts, 2, dtype=torch.float32),
        )
        _add_param(
            "w2_weight_scale",
            torch.ones(num_experts, dtype=torch.float32),
        )

        # Per-channel int4 weight scales.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        _add_param(
            "w13_weight_scale_2",
            torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                dtype=torch.float32,
            ),
        )
        _add_param(
            "w2_weight_scale_2",
            torch.ones(num_experts, hidden_size, dtype=torch.float32),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Shuffle weights for AITER and fold the fp8 scales into the int4 scales."""
        w13_shuffled, w2_shuffled = rocm_aiter_ops.shuffle_weights(
            layer.w13_weight.data, layer.w2_weight.data
        )
        layer.w13_weight = torch.nn.Parameter(w13_shuffled, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_shuffled, requires_grad=False)

        # INT4-FP8: the fp8 MoE kernel needs a single fp8 w13 scale per
        # expert. The fp8 weights are not directly available for
        # requantization, so instead the int4 per-channel scales of the shard
        # whose fp8 scale is not the expert maximum are adjusted by the ratio
        # of the two fp8 scales.
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        assert torch.all(max_w13_scales != 0), "fp8 weight scale cannot be zero."

        for expert_id in range(layer.local_num_experts):
            expert_max = max_w13_scales[expert_id]
            for shard_id in range(2):
                shard_scale = layer.w13_weight_scale[expert_id][shard_id]
                if shard_scale != expert_max:
                    begin = shard_id * shard_size
                    end = begin + shard_size
                    layer.w13_weight_scale_2[expert_id][begin:end] *= (
                        shard_scale / expert_max
                    )

        layer.w13_weight_scale = torch.nn.Parameter(
            max_w13_scales, requires_grad=False
        )

        # Special handling for asm_moe, which applies
        # (weight_scale1 * weight_scale) as post-GEMM scaling. The optimal
        # design would apply the per-column weight_scale1 before the GEMM and
        # weight_scale after it.
        for expert_id in range(layer.local_num_experts):
            layer.w13_weight_scale_2[expert_id] *= max_w13_scales[expert_id]
            layer.w2_weight_scale_2[expert_id] *= layer.w2_weight_scale[expert_id]

    def get_fused_moe_quant_config(self, layer):
        """Build the fused-MoE quant config from the folded per-channel scales."""
        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale_2,
            w2_scale=layer.w2_weight_scale_2,
            per_out_ch_quant=True,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens to experts and run them through the AITER fused kernel."""
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_fused_experts,
        )

        return rocm_aiter_fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            quant_config=self.moe_quant_config,
            expert_map=layer.expert_map,
        )
class
QuarkOCP_MX_MoEMethod
(
QuarkMoEMethod
):
def
__init__
(
self
,
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
View file @
a810671a
...
...
@@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
def
flashinfer_trtllm_fp4_moe
(
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
global_num_experts
:
int
,
...
...
@@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe(
from
vllm.model_executor.models.llama4
import
Llama4MoE
# Quantize input to FP4
a1_gscale
=
layer
.
w13_input_scale_quant
(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
)
=
flashinfer
.
fp4_quantize
(
x
,
a1_gscale
,
is_sf_swizzled_layout
=
False
,
)
if
isinstance
(
x
,
tuple
):
hidden_states_fp4
,
hidden_states_scale_linear_fp4
=
x
else
:
# hidden_states is the already quantized
a1_gscale
=
layer
.
w13_input_scale_quant
(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
)
=
flashinfer
.
fp4_quantize
(
x
,
a1_gscale
,
is_sf_swizzled_layout
=
False
,
)
# Determine routing method type
use_llama4_routing
=
custom_routing_function
is
Llama4MoE
.
custom_routing_function
...
...
@@ -301,18 +305,14 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_scale
=
hidden_states_scale_linear_fp4
.
view
(
torch
.
float8_e4m3fn
).
flatten
(),
gemm1_weights
=
layer
.
gemm1_weights_fp4_shuffled
.
data
,
gemm1_weights_scale
=
layer
.
gemm1_scales_fp4_shuffled
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm1_weights
=
layer
.
w13_weight
.
data
,
gemm1_weights_scale
=
layer
.
w13_weight_scale
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm1_bias
=
None
,
gemm1_alpha
=
None
,
gemm1_beta
=
None
,
gemm1_clamp_limit
=
None
,
gemm2_weights
=
layer
.
gemm2_weights_fp4_shuffled
.
data
,
gemm2_weights_scale
=
layer
.
gemm2_scales_fp4_shuffled
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm2_weights
=
layer
.
w2_weight
.
data
,
gemm2_weights_scale
=
layer
.
w2_weight_scale
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm2_bias
=
None
,
output1_scale_scalar
=
layer
.
g1_scale_c
.
data
,
output1_scale_gate_scalar
=
layer
.
g1_alphas
.
data
,
...
...
@@ -364,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe(
torch
.
bfloat16
).
view
(
torch
.
int16
)
# Quantize input to FP4
a1_gscale
=
layer
.
w13_input_scale_quant
(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
)
=
flashinfer
.
fp4_quantize
(
x
,
a1_gscale
,
is_sf_swizzled_layout
=
False
,
)
if
isinstance
(
x
,
tuple
):
# Hidden_states is the already quantized
hidden_states_fp4
,
hidden_states_scale_linear_fp4
=
x
else
:
# Quantize input to FP4
a1_gscale
=
layer
.
w13_input_scale_quant
(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
)
=
flashinfer
.
fp4_quantize
(
x
,
a1_gscale
,
is_sf_swizzled_layout
=
False
,
)
# Call TRT-LLM FP4 block-scale MoE kernel
out
=
flashinfer
.
fused_moe
.
trtllm_fp4_block_scale_routed_moe
(
...
...
@@ -380,18 +384,14 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_scale
=
hidden_states_scale_linear_fp4
.
view
(
torch
.
float8_e4m3fn
).
flatten
(),
gemm1_weights
=
layer
.
gemm1_weights_fp4_shuffled
.
data
,
gemm1_weights_scale
=
layer
.
gemm1_scales_fp4_shuffled
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm1_weights
=
layer
.
w13_weight
.
data
,
gemm1_weights_scale
=
layer
.
w13_weight_scale
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm1_bias
=
None
,
gemm1_alpha
=
None
,
gemm1_beta
=
None
,
gemm1_clamp_limit
=
None
,
gemm2_weights
=
layer
.
gemm2_weights_fp4_shuffled
.
data
,
gemm2_weights_scale
=
layer
.
gemm2_scales_fp4_shuffled
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm2_weights
=
layer
.
w2_weight
.
data
,
gemm2_weights_scale
=
layer
.
w2_weight_scale
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm2_bias
=
None
,
output1_scale_scalar
=
layer
.
g1_scale_c
.
data
,
output1_scale_gate_scalar
=
layer
.
g1_alphas
.
data
,
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
a810671a
...
...
@@ -1437,14 +1437,17 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
layer
.
orig_dtype
,
layer
.
weight
)
if
should_use_deepgemm
:
scale_attr
=
(
"weight_scale_inv"
if
hasattr
(
layer
,
"weight_scale_inv"
)
else
"weight_scale"
)
dg_weight
,
dg_weight_scale
=
deepgemm_post_process_fp8_weight_block
(
wq
=
layer
.
weight
.
data
,
ws
=
layer
.
weight_
scale_
inv
.
data
,
ws
=
getattr
(
layer
,
scale_
attr
)
.
data
,
quant_block_shape
=
tuple
(
layer
.
weight_block_size
),
use_e8m0
=
is_deep_gemm_e8m0_used
(),
)
replace_parameter
(
layer
,
"weight"
,
dg_weight
)
replace_parameter
(
layer
,
"weight_
scale_
inv"
,
dg_weight_scale
)
replace_parameter
(
layer
,
scale_
attr
,
dg_weight_scale
)
def
expert_weight_is_col_major
(
x
:
torch
.
Tensor
)
->
bool
:
...
...
vllm/model_executor/layers/rotary_embedding/base.py
View file @
a810671a
...
...
@@ -38,7 +38,10 @@ class RotaryEmbeddingBase(CustomOp):
# and current_platform.is_cuda()
# and has_flashinfer()
# and self.head_size in [64, 128, 256, 512])
self
.
use_flashinfer
=
False
# Check if use_flashinfer is already set
if
not
hasattr
(
self
,
"use_flashinfer"
):
self
.
use_flashinfer
=
False
cache
=
self
.
_compute_cos_sin_cache
()
if
not
self
.
use_flashinfer
:
...
...
vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
View file @
a810671a
...
...
@@ -6,6 +6,7 @@ import math
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer
from
.base
import
RotaryEmbeddingBase
from
.common
import
(
...
...
@@ -56,6 +57,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
/
yarn_get_mscale
(
self
.
scaling_factor
,
float
(
mscale_all_dim
))
*
attn_factor
)
self
.
use_flashinfer
=
(
self
.
enabled
()
and
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
)
and
current_platform
.
is_cuda
()
and
has_flashinfer
()
and
head_size
in
[
64
,
128
,
256
,
512
]
)
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
...
...
@@ -162,4 +170,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
key
:
torch
.
Tensor
|
None
=
None
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
)
if
self
.
use_flashinfer
:
torch
.
ops
.
vllm
.
flashinfer_rotary_embedding
(
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
,
)
return
query
,
key
else
:
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
)
vllm/model_executor/model_loader/weight_utils.py
View file @
a810671a
...
...
@@ -23,6 +23,7 @@ import torch
from
huggingface_hub
import
HfFileSystem
,
hf_hub_download
,
snapshot_download
from
safetensors.torch
import
load
,
load_file
,
safe_open
,
save_file
from
tqdm.auto
import
tqdm
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm
import
envs
from
vllm.config
import
ModelConfig
...
...
@@ -448,12 +449,31 @@ def download_weights_from_hf(
fs
=
HfFileSystem
()
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
# Use the first pattern found in the HF repo's files.
for
pattern
in
allow_patterns
:
matching
=
fnmatch
.
filter
(
file_list
,
pattern
)
if
len
(
matching
)
>
0
:
allow_patterns
=
[
pattern
]
break
# If downloading safetensors and an index file exists, use the
# specific file names from the index to avoid downloading
# unnecessary files (e.g., from subdirectories like "original/").
index_file
=
f
"
{
model_name_or_path
}
/
{
SAFE_WEIGHTS_INDEX_NAME
}
"
if
"*.safetensors"
in
allow_patterns
and
index_file
in
file_list
:
index_path
=
hf_hub_download
(
repo_id
=
model_name_or_path
,
filename
=
SAFE_WEIGHTS_INDEX_NAME
,
cache_dir
=
cache_dir
,
revision
=
revision
,
)
with
open
(
index_path
)
as
f
:
weight_map
=
json
.
load
(
f
)[
"weight_map"
]
if
weight_map
:
# Extra [] so that weight_map files are treated as a
# single allow_pattern in the loop below
allow_patterns
=
[
list
(
set
(
weight_map
.
values
()))]
# type: ignore[list-item]
else
:
allow_patterns
=
[
"*.safetensors"
]
else
:
# Use the first pattern found in the HF repo's files.
for
pattern
in
allow_patterns
:
if
fnmatch
.
filter
(
file_list
,
pattern
):
allow_patterns
=
[
pattern
]
break
except
Exception
as
e
:
logger
.
warning
(
"Failed to get file list for '%s'. Trying each pattern in "
...
...
@@ -480,6 +500,9 @@ def download_weights_from_hf(
)
# If we have downloaded weights for this allow_pattern,
# we don't need to check the rest.
# allow_pattern can be a list (from weight_map) or str (glob)
if
isinstance
(
allow_pattern
,
list
):
break
if
any
(
Path
(
hf_folder
).
glob
(
allow_pattern
)):
break
time_taken
=
time
.
perf_counter
()
-
start_time
...
...
vllm/model_executor/models/aimv2.py
View file @
a810671a
...
...
@@ -8,7 +8,7 @@ from collections.abc import Iterable
import
torch
import
torch.nn
as
nn
from
vllm.attention.layer
import
M
ultiHead
Attention
from
vllm.attention.layer
s.mm_encoder_attention
import
M
MEncoder
Attention
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.utils
import
divide
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
...
@@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
attn
=
M
ultiHead
Attention
(
self
.
attn
=
M
MEncoder
Attention
(
self
.
num_heads_per_partition
,
self
.
head_dim
,
self
.
scale
)
...
...
vllm/model_executor/models/bert.py
View file @
a810671a
...
...
@@ -55,7 +55,9 @@ class BertEmbedding(nn.Module):
"position_ids"
,
torch
.
arange
(
config
.
max_position_embeddings
).
unsqueeze
(
0
),
)
self
.
position_embedding_type
=
config
.
position_embedding_type
self
.
position_embedding_type
=
getattr
(
config
,
"position_embedding_type"
,
"absolute"
)
if
self
.
position_embedding_type
!=
"absolute"
:
raise
ValueError
(
"Only 'absolute' position_embedding_type"
+
" is supported"
...
...
vllm/model_executor/models/blip.py
View file @
a810671a
...
...
@@ -9,7 +9,7 @@ import torch
import
torch.nn
as
nn
from
transformers
import
Blip2VisionConfig
,
BlipVisionConfig
from
vllm.attention.layer
import
M
ultiHead
Attention
from
vllm.attention.layer
s.mm_encoder_attention
import
M
MEncoder
Attention
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.conv
import
Conv2dLayer
...
...
@@ -122,7 +122,7 @@ class BlipAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
attn
=
M
ultiHead
Attention
(
self
.
attn
=
M
MEncoder
Attention
(
self
.
num_heads_per_partition
,
self
.
head_dim
,
self
.
scale
)
...
...
vllm/model_executor/models/clip.py
View file @
a810671a
...
...
@@ -14,7 +14,8 @@ from transformers import (
CLIPVisionConfig
,
)
from
vllm.attention.layer
import
Attention
,
MultiHeadAttention
from
vllm.attention.layer
import
Attention
from
vllm.attention.layers.mm_encoder_attention
import
MMEncoderAttention
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
...
...
@@ -354,7 +355,7 @@ class CLIPAttention(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
*
,
prefix
:
str
=
""
,
attn_cls
:
type
[
Attention
]
|
type
[
M
ultiHead
Attention
],
attn_cls
:
type
[
Attention
]
|
type
[
M
MEncoder
Attention
],
)
->
None
:
super
().
__init__
()
...
...
@@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
*
,
prefix
:
str
=
""
,
attn_cls
:
type
[
Attention
]
|
type
[
M
ultiHead
Attention
],
attn_cls
:
type
[
Attention
]
|
type
[
M
MEncoder
Attention
],
)
->
None
:
super
().
__init__
()
self
.
self_attn
=
CLIPAttention
(
...
...
@@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module):
num_hidden_layers_override
:
int
|
None
=
None
,
*
,
prefix
:
str
=
""
,
attn_cls
:
type
[
Attention
]
|
type
[
M
ultiHead
Attention
],
attn_cls
:
type
[
Attention
]
|
type
[
M
MEncoder
Attention
],
)
->
None
:
super
().
__init__
()
...
...
@@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module):
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
,
prefix
=
f
"
{
prefix
}
.encoder"
,
attn_cls
=
M
ultiHead
Attention
,
attn_cls
=
M
MEncoder
Attention
,
)
num_hidden_layers
=
config
.
num_hidden_layers
...
...
vllm/model_executor/models/config.py
View file @
a810671a
...
...
@@ -308,12 +308,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
if
cache_config
.
mamba_block_size
is
None
:
cache_config
.
mamba_block_size
=
model_config
.
max_model_len
# TODO(tdoublep): remove once cascade attention is supported
logger
.
info
(
"Disabling cascade attention since it is not supported for hybrid models."
)
model_config
.
disable_cascade_attn
=
True
class
HybridAttentionMambaModelConfig
(
VerifyAndUpdateConfig
):
@
classmethod
...
...
vllm/model_executor/models/deepencoder.py
View file @
a810671a
...
...
@@ -18,7 +18,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
transformers
import
CLIPVisionConfig
from
vllm.attention.layer
import
M
ultiHead
Attention
from
vllm.attention.layer
s.mm_encoder_attention
import
M
MEncoder
Attention
from
vllm.model_executor.layers.conv
import
Conv2dLayer
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
,
prefix
=
f
"
{
prefix
}
.encoder"
,
attn_cls
=
M
ultiHead
Attention
,
attn_cls
=
M
MEncoder
Attention
,
)
num_hidden_layers
=
config
.
num_hidden_layers
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
a810671a
...
...
@@ -141,6 +141,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
prefix
=
maybe_prefix
(
prefix
,
"embed_tokens"
),
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
a810671a
...
...
@@ -837,7 +837,11 @@ class Indexer(nn.Module):
)
self
.
k_norm
=
LayerNorm
(
self
.
head_dim
,
eps
=
1e-6
)
self
.
weights_proj
=
ReplicatedLinear
(
hidden_size
,
self
.
n_head
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.weights_proj"
hidden_size
,
self
.
n_head
,
bias
=
False
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.weights_proj"
,
)
self
.
softmax_scale
=
self
.
head_dim
**-
0.5
...
...
vllm/model_executor/models/gemma3.py
View file @
a810671a
...
...
@@ -38,7 +38,10 @@ from vllm.model_executor.layers.linear import (
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
,
...
...
@@ -463,12 +466,20 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super
().
__init__
()
self
.
config
=
config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
quant_config
=
quant_config
self
.
model
=
Gemma3Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
model
.
embed_tokens
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
soft_cap
=
config
.
final_logit_softcapping
)
...
...
@@ -496,7 +507,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
,
hidden_states
)
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
...
vllm/model_executor/models/glm4v.py
View file @
a810671a
...
...
@@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType
from
transformers.image_utils
import
ImageInput
from
transformers.tokenization_utils_base
import
TextInput
from
vllm.attention.layer
import
M
ultiHead
Attention
from
vllm.attention.layer
s.mm_encoder_attention
import
M
MEncoder
Attention
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
get_tensor_model_parallel_world_size
...
...
@@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.dense"
,
)
self
.
attn
=
M
ultiHead
Attention
(
self
.
attn
=
M
MEncoder
Attention
(
self
.
num_heads_per_rank
,
self
.
head_dim
,
self
.
scale
)
self
.
output_dropout
=
torch
.
nn
.
Dropout
(
config
.
dropout_prob
)
...
...
Prev
1
…
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment