Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3f8d5dd
Commit
a3f8d5dd
authored
Dec 17, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori
parents
8d75f22e
f34eca5f
Changes
499
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
572 additions
and
247 deletions
+572
-247
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+23
-18
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+4
-1
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+72
-0
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+0
-2
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+1
-1
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
+17
-6
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+5
-6
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
...del_executor/layers/fused_moe/fused_moe_modular_method.py
+1
-6
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+19
-5
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+29
-65
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+2
-2
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+94
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+0
-3
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
...pressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
+1
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+157
-92
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+6
-0
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+109
-1
vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
...rs/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
+4
-1
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
...xecutor/layers/quantization/kernels/scaled_mm/__init__.py
+13
-29
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
...l_executor/layers/quantization/kernels/scaled_mm/aiter.py
+15
-7
No files found.
vllm/model_executor/layers/batch_invariant.py
View file @
a3f8d5dd
...
...
@@ -6,7 +6,7 @@ from typing import Any
import
torch
import
vllm.envs
as
envs
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
...
...
@@ -936,7 +936,7 @@ def enable_batch_invariant_mode():
# Batch invariant matmuls are no longer needed after cublas overrides
if
not
is_torch_equal_or_newer
(
"2.10.0.dev"
):
if
(
current_platform
.
is_device_capability
(
100
)
current_platform
.
is_device_capability
_family
(
100
)
or
current_platform
.
is_device_capability
(
80
)
or
current_platform
.
is_device_capability
(
89
)
):
...
...
@@ -1004,27 +1004,30 @@ def vllm_is_batch_invariant() -> bool:
return
VLLM_BATCH_INVARIANT
def
override_envs_for_invariance
():
curr_attn_backend
=
envs
.
VLLM_ATTENTION_BACKEND
def
override_envs_for_invariance
(
attention_backend
:
AttentionBackendEnum
|
None
,
):
supported_backends
=
[
"
FLASH_ATTN
"
,
# best supported backend
"
FLASHINFER
"
,
"
FLASH_ATTN_MLA
"
,
"
TRITON_MLA
"
,
AttentionBackendEnum
.
FLASH_ATTN
,
# best supported backend
AttentionBackendEnum
.
FLASHINFER
,
AttentionBackendEnum
.
FLASH_ATTN_MLA
,
AttentionBackendEnum
.
TRITON_MLA
,
# Not yet supported MLA backends
#
"
FLASHMLA
"
,
#
"
FLEX_ATTENTION
"
, # IMA issue
even if we disable batch invariance
#
"
FLASHINFER_MLA
"
,
https://github.com/vllm-project/vllm/pull/
28967
#
AttentionBackendEnum.
FLASHMLA,
#
AttentionBackendEnum.
FLEX_ATTENTION,
# IMA issue
#
AttentionBackendEnum.
FLASHINFER_MLA,
# PR #
28967
]
if
curr_attn_backend
not
in
supported_backends
:
if
attention_backend
not
in
supported_backends
:
supported_names
=
[
b
.
name
for
b
in
supported_backends
]
backend_name
=
attention_backend
.
name
if
attention_backend
else
None
error
=
(
"VLLM batch_invariant mode requires an attention backend in "
f
"
{
supported_
backend
s
}
, but got '
{
curr_attn_
backend
}
'. "
"Please se
t the 'VLLM_ATTENTION_BACKEND' environment variable
"
"
to
one of the supported backends before enabling batch_invariant."
f
"
{
supported_
name
s
}
, but got '
{
backend
_name
}
'. "
"Please
u
se
--attention-backend or attention_config to set
"
"one of the supported backends before enabling batch_invariant."
)
raise
RuntimeError
(
error
)
if
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
!=
supported_backends
[
0
]:
if
attention_backend
!=
supported_backends
[
0
]:
warning
=
(
"You are using a decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
...
...
@@ -1050,10 +1053,12 @@ def override_envs_for_invariance():
os
.
environ
[
"VLLM_USE_AOT_COMPILE"
]
=
"0"
def
init_batch_invariance
():
def
init_batch_invariance
(
attention_backend
:
AttentionBackendEnum
|
None
,
):
# this will hit all the csrc overrides as well
if
vllm_is_batch_invariant
():
override_envs_for_invariance
()
override_envs_for_invariance
(
attention_backend
)
enable_batch_invariant_mode
()
# Disable TF32 for batch invariance - it causes non-deterministic rounding
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
a3f8d5dd
...
...
@@ -287,7 +287,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
DeepGemm supports packed ue8m0 activation scales format in devices == sm100
"""
return
is_deep_gemm_e8m0_used
()
and
current_platform
.
is_device_capability
(
100
)
return
(
is_deep_gemm_e8m0_used
()
and
current_platform
.
is_device_capability_family
(
100
)
)
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# Let PrepareAndFinalize::finalize() decide the impl.
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
a3f8d5dd
...
...
@@ -543,6 +543,42 @@ def int8_w8a8_moe_quant_config(
)
def
gptq_marlin_moe_quant_config
(
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
weight_bits
:
int
,
group_size
:
int
,
w1_zp
:
torch
.
Tensor
|
None
=
None
,
w2_zp
:
torch
.
Tensor
|
None
=
None
,
w1_bias
:
torch
.
Tensor
|
None
=
None
,
w2_bias
:
torch
.
Tensor
|
None
=
None
,
):
"""
Construct a quant config for gptq marlin quantization.
"""
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
w_shape
=
None
if
group_size
==
-
1
else
GroupShape
(
row
=
1
,
col
=
group_size
)
# Activations are NOT quantized for GPTQ (fp16/bf16)
a_shape
=
w_shape
# Same as weight shape for alignment
# Determine weight dtype
if
weight_bits
==
4
:
weight_dtype
=
"int4"
elif
weight_bits
==
8
:
weight_dtype
=
torch
.
int8
else
:
raise
ValueError
(
f
"Unsupported weight_bits:
{
weight_bits
}
"
)
return
FusedMoEQuantConfig
(
_a1
=
FusedMoEQuantDesc
(
dtype
=
None
,
shape
=
a_shape
),
_a2
=
FusedMoEQuantDesc
(
dtype
=
None
,
shape
=
a_shape
),
_w1
=
FusedMoEQuantDesc
(
weight_dtype
,
w_shape
,
w1_scale
,
None
,
w1_zp
,
w1_bias
),
_w2
=
FusedMoEQuantDesc
(
weight_dtype
,
w_shape
,
w2_scale
,
None
,
w2_zp
,
w2_bias
),
)
def
mxfp4_w4a16_moe_quant_config
(
w1_scale
:
Union
[
torch
.
Tensor
,
"PrecisionConfig"
],
w2_scale
:
Union
[
torch
.
Tensor
,
"PrecisionConfig"
],
...
...
@@ -700,6 +736,42 @@ def int4_w4afp8_moe_quant_config(
)
def
awq_marlin_moe_quant_config
(
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w1_zp
:
torch
.
Tensor
|
None
,
w2_zp
:
torch
.
Tensor
|
None
,
weight_bits
:
int
,
group_size
:
int
,
w1_bias
:
torch
.
Tensor
|
None
=
None
,
w2_bias
:
torch
.
Tensor
|
None
=
None
,
)
->
FusedMoEQuantConfig
:
"""
Construct a quant config for awq marlin quantization.
"""
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
w_shape
=
None
if
group_size
==
-
1
else
GroupShape
(
row
=
1
,
col
=
group_size
)
# Activations are NOT quantized for AWQ (fp16/bf16)
a_shape
=
w_shape
# Same as weight shape for alignment
# Determine weight dtype
if
weight_bits
==
4
:
weight_dtype
=
"int4"
elif
weight_bits
==
8
:
weight_dtype
=
torch
.
int8
else
:
raise
ValueError
(
f
"Unsupported weight_bits:
{
weight_bits
}
"
)
return
FusedMoEQuantConfig
(
_a1
=
FusedMoEQuantDesc
(
dtype
=
None
,
shape
=
a_shape
),
_a2
=
FusedMoEQuantDesc
(
dtype
=
None
,
shape
=
a_shape
),
_w1
=
FusedMoEQuantDesc
(
weight_dtype
,
w_shape
,
w1_scale
,
None
,
w1_zp
,
w1_bias
),
_w2
=
FusedMoEQuantDesc
(
weight_dtype
,
w_shape
,
w2_scale
,
None
,
w2_zp
,
w2_bias
),
)
def
biased_moe_quant_config
(
w1_bias
:
torch
.
Tensor
|
None
,
w2_bias
:
torch
.
Tensor
|
None
,
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
a3f8d5dd
...
...
@@ -460,7 +460,6 @@ def cutlass_moe_fp8(
expert_map
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
parallel_config
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
...
...
@@ -538,7 +537,6 @@ def cutlass_moe_fp8(
c_strides2
=
c_strides2
,
quant_config
=
quant_config
,
),
parallel_config
=
parallel_config
,
)
return
fn
(
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
a3f8d5dd
...
...
@@ -293,7 +293,7 @@ def deep_gemm_moe_fp8(
expert_map
:
torch
.
Tensor
|
None
=
None
,
a1_scale
:
torch
.
Tensor
|
None
=
None
,
a2_scale
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
View file @
a3f8d5dd
...
...
@@ -84,10 +84,16 @@ def _fwd_kernel_ep_scatter_1(
m_indices_start_ptr
=
m_indices
+
cur_expert_start
off_expert
=
tl
.
arange
(
0
,
BLOCK_E
)
# any rows in the per-expert aligned region that do not correspond to
# real tokens are left untouched here and should remain initialized to
# -1 so DeepGEMM can skip them
for
start_m
in
tl
.
range
(
0
,
cur_expert_token_num
,
BLOCK_E
,
num_stages
=
4
):
offs
=
start_m
+
off_expert
mask
=
offs
<
cur_expert_token_num
tl
.
store
(
m_indices_start_ptr
+
start_m
+
off_expert
,
m_indices_start_ptr
+
offs
,
cur_expert
,
mask
=
mask
,
)
...
...
@@ -366,12 +372,17 @@ def deepgemm_moe_permute(
(
M_sum
,
H
//
block_k
),
device
=
device
,
dtype
=
torch
.
float32
)
maybe_has_empty_blocks
=
(
expert_tokens_meta
is
None
)
or
(
expert_tokens_meta
.
expert_num_tokens_cpu
is
None
# DeepGEMM uses negative values in m_indices (here expert_ids) to mark
# completely invalid / padded blocks that should be skipped. We always
# initialize expert_ids to -1 so any row that is not explicitly written
# by the scatter kernel will be treated as invalid and skipped by
# DeepGEMM's scheduler.
expert_ids
=
torch
.
full
(
(
M_sum
,),
fill_value
=-
1
,
device
=
device
,
dtype
=
torch
.
int32
,
)
expert_ids_init
=
torch
.
zeros
if
maybe_has_empty_blocks
else
torch
.
empty
expert_ids
=
expert_ids_init
((
M_sum
),
device
=
device
,
dtype
=
torch
.
int32
)
inv_perm
=
torch
.
empty
(
topk_ids
.
shape
,
device
=
device
,
dtype
=
torch
.
int32
)
expert_num_tokens
=
None
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
a3f8d5dd
...
...
@@ -903,12 +903,11 @@ def get_moe_configs(
# If no optimized configuration is available, we will use the default
# configuration
logger
.
warning
(
(
logger
.
warning_once
(
"Using default MoE config. Performance might be sub-optimal! "
"Config file not found at %s"
),
co
nfig_file_paths
,
"Config file not found at %s"
,
", "
.
join
(
config_file_paths
),
s
co
pe
=
"local"
,
)
return
None
...
...
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
View file @
a3f8d5dd
...
...
@@ -43,11 +43,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
shared_experts
:
torch
.
nn
.
Module
|
None
,
)
->
"FusedMoEModularMethod"
:
parallel_config
=
getattr
(
getattr
(
moe_layer
,
"vllm_config"
,
None
),
"parallel_config"
,
None
,
)
return
FusedMoEModularMethod
(
old_quant_method
,
FusedMoEModularKernel
(
...
...
@@ -55,7 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method
.
select_gemm_impl
(
prepare_finalize
,
moe_layer
),
shared_experts
,
getattr
(
moe_layer
,
"shared_experts_stream"
,
None
),
parallel_config
=
parallel_config
,
moe_
parallel_config
=
moe_layer
.
moe_
parallel_config
,
),
)
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
a3f8d5dd
...
...
@@ -371,7 +371,9 @@ class FusedMoE(CustomOp):
# aux_stream() returns None on non-cuda-alike platforms.
self
.
shared_experts_stream
=
aux_stream
()
if
self
.
shared_experts_stream
is
not
None
:
logger
.
info_once
(
"Enabled separate cuda stream for MoE shared_experts"
)
logger
.
info_once
(
"Enabled separate cuda stream for MoE shared_experts"
,
scope
=
"local"
)
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
...
...
@@ -891,7 +893,7 @@ class FusedMoE(CustomOp):
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We dont need shared_output.record_stream(current_stream())
# NOTE: We don
'
t need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone
.
record_stream
(
self
.
shared_experts_stream
)
...
...
@@ -1222,10 +1224,14 @@ class FusedMoE(CustomOp):
if
full_load
:
shard_dim
+=
1
# Materialize GGUF UninitializedParameter
# Materialize GGUF UninitializedParameter
accounting merged weights
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
# To materialize a tensor, we must have full shape including
# number of experts, making this portion to require `full_load`.
assert
full_load
final_shape
=
list
(
loaded_weight
.
shape
)
if
shard_id
in
[
"w1"
,
"w3"
]:
# w1 and w3 are merged per expert.
if
shard_id
in
{
"w1"
,
"w3"
}:
final_shape
[
1
]
*=
2
final_shape
[
shard_dim
]
=
final_shape
[
shard_dim
]
//
self
.
tp_size
param
.
materialize
(
final_shape
,
dtype
=
loaded_weight
.
dtype
)
...
...
@@ -1578,6 +1584,14 @@ class FusedMoE(CustomOp):
f
"EPLB is not supported for
{
self
.
quant_method
.
method_name
}
."
)
def
valid_grouping
()
->
bool
:
# Check if num_experts is greater than num_expert_group
# and is divisible by num_expert_group
num_experts
=
router_logits
.
shape
[
-
1
]
if
num_experts
<=
self
.
num_expert_group
:
return
False
return
num_experts
%
self
.
num_expert_group
==
0
indices_type
=
self
.
quant_method
.
topk_indices_dtype
# Check if we should use a routing simulation strategy
...
...
@@ -1592,7 +1606,7 @@ class FusedMoE(CustomOp):
)
# DeepSeekv2 uses grouped_top_k
elif
self
.
use_grouped_topk
:
elif
self
.
use_grouped_topk
and
valid_grouping
()
:
assert
self
.
topk_group
is
not
None
assert
self
.
num_expert_group
is
not
None
if
rocm_aiter_ops
.
is_fused_moe_enabled
():
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
a3f8d5dd
...
...
@@ -10,10 +10,12 @@ from typing import final
import
torch
import
vllm.envs
as
envs
from
vllm.config
import
ParallelConfig
,
get_current_vllm_config
from
vllm.forward_context
import
get_forward_context
,
is_forward_context_available
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_resize_cache
,
count_expert_num_tokens
,
...
...
@@ -22,12 +24,12 @@ from vllm.model_executor.layers.fused_moe.utils import (
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.worker.ubatching
import
(
dbo_current_ubatch_id
,
dbo_enabled
,
dbo_maybe_run_recv_hook
,
dbo_register_recv_hook
,
dbo_yield
,
)
from
vllm.v1.worker.workspace
import
current_workspace_manager
logger
=
init_logger
(
__name__
)
...
...
@@ -661,25 +663,6 @@ def _slice_scales(
return
None
class
SharedResizableBuffer
:
def
__init__
(
self
):
self
.
buffer
=
None
def
get
(
self
,
shape
:
tuple
[
int
,
...],
device
:
torch
.
device
,
dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
assert
shape
!=
()
shape_numel
=
prod
(
shape
)
if
(
self
.
buffer
is
None
or
self
.
buffer
.
numel
()
<
shape_numel
or
self
.
buffer
.
device
!=
device
or
self
.
buffer
.
dtype
!=
dtype
):
self
.
buffer
=
torch
.
empty
(
shape_numel
,
device
=
device
,
dtype
=
dtype
)
return
self
.
buffer
[:
shape_numel
].
view
(
*
shape
)
@
final
class
FusedMoEModularKernel
(
torch
.
nn
.
Module
):
"""
...
...
@@ -694,29 +677,13 @@ class FusedMoEModularKernel(torch.nn.Module):
objects.
"""
class
SharedBuffers
:
def
__init__
(
self
)
->
None
:
self
.
fused_out
=
SharedResizableBuffer
()
self
.
workspace13
=
SharedResizableBuffer
()
self
.
workspace2
=
SharedResizableBuffer
()
# Persistent buffers that are shared across `FusedMoEModularKernel`
# instances (layers), to save memory and allocattions.
#
# We have two sets of buffers to support dual batch overlap (DBO) where each
# microbatch (ubatch) should use its own set of buffers to avoid
# cross-ubatch contimination.
# NOTE that memory is lazily allocated for these buffers, meaning that if
# DBO isn't being used, the second SharedBuffers will be empty.
shared_buffers
:
list
[
SharedBuffers
]
=
[
SharedBuffers
(),
SharedBuffers
()]
def
__init__
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
fused_experts
:
FusedMoEPermuteExpertsUnpermute
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
shared_experts_stream
:
torch
.
cuda
.
Stream
|
None
=
None
,
parallel_config
:
ParallelConfig
|
None
=
None
,
moe_
parallel_config
:
FusedMoE
ParallelConfig
|
None
=
None
,
):
super
().
__init__
()
self
.
prepare_finalize
=
prepare_finalize
...
...
@@ -724,12 +691,15 @@ class FusedMoEModularKernel(torch.nn.Module):
self
.
shared_experts
=
shared_experts
self
.
shared_experts_stream
=
shared_experts_stream
# cache whether this worker is using DP+EP
if
parallel_config
is
None
:
parallel_config
=
get_current_vllm_config
().
parallel_config
# prefer an explicit FusedMoEParallelConfig when available (from
# FusedMoE layers / tests).
# if not provided, assume this kernel is
# running in a non-DP+EP context
self
.
moe_parallel_config
:
FusedMoEParallelConfig
|
None
=
moe_parallel_config
self
.
is_dp_ep
=
(
parallel_config
.
data_parallel_size
>
1
and
parallel_config
.
enable_expert_parallel
moe_parallel_config
is
not
None
and
moe_parallel_config
.
dp_size
>
1
and
moe_parallel_config
.
use_ep
)
self
.
_post_init_setup
()
...
...
@@ -806,10 +776,6 @@ class FusedMoEModularKernel(torch.nn.Module):
assert
M_full
>
0
and
M_chunk
>
0
num_chunks
,
_
=
self
.
_chunk_info
(
M_full
)
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
ubatch_idx
=
dbo_current_ubatch_id
()
buffers
=
self
.
shared_buffers
[
ubatch_idx
]
workspace_dtype
=
self
.
fused_experts
.
workspace_dtype
(
out_dtype
)
# Force worst-case allocation in profiling run for
...
...
@@ -832,14 +798,11 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta
,
)
)
buffers
.
workspace13
.
get
(
max_workspace_13
,
device
=
device
,
dtype
=
workspace_dtype
)
buffers
.
workspace2
.
get
(
max_workspace_2
,
device
=
device
,
dtype
=
workspace_dtype
)
buffers
.
fused_out
.
get
(
max_fused_out_shape
,
device
=
device
,
dtype
=
workspace_dtype
current_workspace_manager
().
get_simultaneous
(
(
max_workspace_13
,
workspace_dtype
),
(
max_workspace_2
,
workspace_dtype
),
(
max_fused_out_shape
,
out_dtype
),
)
# Get intermediate workspace shapes based off the chunked M size.
...
...
@@ -866,22 +829,23 @@ class FusedMoEModularKernel(torch.nn.Module):
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13
=
buffers
.
workspace13
.
get
(
workspace13_shape
,
device
=
device
,
dtype
=
workspace_dtype
)
workspace2
=
buffers
.
workspace2
.
get
(
workspace2_shape
,
device
=
device
,
dtype
=
workspace_dtype
)
# Construct the entire output that can then be processed in chunks.
# Reuse workspace13 for the output in the non-chunked case as long
# as it is large enough. This will not always be the case for standard
# format experts and with experts that have empty workspaces.
if
num_chunks
==
1
and
prod
(
workspace13_shape
)
>=
prod
(
fused_out_shape
):
workspace13
,
workspace2
=
current_workspace_manager
().
get_simultaneous
(
(
workspace13_shape
,
workspace_dtype
),
(
workspace2_shape
,
workspace_dtype
),
)
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
else
:
fused_out
=
buffers
.
fused_out
.
get
(
fused_out_shape
,
device
=
device
,
dtype
=
out_dtype
workspace13
,
workspace2
,
fused_out
=
(
current_workspace_manager
().
get_simultaneous
(
(
workspace13_shape
,
workspace_dtype
),
(
workspace2_shape
,
workspace_dtype
),
(
fused_out_shape
,
out_dtype
),
)
)
return
workspace13
,
workspace2
,
fused_out
...
...
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
View file @
a3f8d5dd
...
...
@@ -30,8 +30,8 @@ class SharedFusedMoE(FusedMoE):
# Disable shared expert overlap if:
# - we are using eplb, because of correctness issues
# - we are using flashinfer with DP, since there nothin
t
to gain
# - we are using marlin k
j
ernels
# - we are using flashinfer with DP, since there nothin
g
to gain
# - we are using marlin kernels
self
.
use_overlapped
=
(
use_overlapped
and
not
(
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
a3f8d5dd
...
...
@@ -470,6 +470,11 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
}
)
intermediate_size_full
=
extra_weight_attrs
.
pop
(
"intermediate_size_full"
,
intermediate_size_per_partition
)
self
.
is_k_full
=
intermediate_size_per_partition
==
intermediate_size_full
w13_qweight
=
Parameter
(
torch
.
empty
(
num_experts
,
...
...
@@ -597,6 +602,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
)
replace_parameter
(
layer
,
"w2_qweight"
,
marlin_w2_qweight
)
# The modular kernel expects w13_weight and w2_weight,
# but AWQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer
.
w13_weight
=
layer
.
w13_qweight
# Alias for modular kernel
layer
.
w2_weight
=
layer
.
w2_qweight
# Why does this take the intermediate size for size_k?
marlin_w13_scales
=
marlin_moe_permute_scales
(
s
=
layer
.
w13_scales
,
...
...
@@ -661,7 +673,88 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
return
None
from
vllm.model_executor.layers.fused_moe.config
import
(
awq_marlin_moe_quant_config
,
)
return
awq_marlin_moe_quant_config
(
w1_scale
=
layer
.
w13_scales
,
w2_scale
=
layer
.
w2_scales
,
weight_bits
=
self
.
quant_config
.
weight_bits
,
group_size
=
self
.
quant_config
.
group_size
,
w1_zp
=
getattr
(
layer
,
"w13_qzeros"
,
None
)
if
self
.
quant_config
.
zero_point
else
None
,
w2_zp
=
getattr
(
layer
,
"w2_qzeros"
,
None
)
if
self
.
quant_config
.
zero_point
else
None
,
w1_bias
=
getattr
(
layer
,
"w13_bias"
,
None
),
w2_bias
=
getattr
(
layer
,
"w2_bias"
,
None
),
)
def
select_gemm_impl
(
self
,
prepare_finalize
,
layer
:
torch
.
nn
.
Module
,
):
"""
Select the GEMM implementation for AWQ-Marlin MoE.
Returns MarlinExperts configured for AWQ quantization.
This is ONLY used when LoRA is enabled.
Without LoRA, AWQ uses its own apply() method.
"""
# Only use modular kernels when LoRA is enabled
# Without LoRA, AWQ's own apply() method works fine and is more efficient
if
not
self
.
moe
.
is_lora_enabled
:
raise
NotImplementedError
(
"AWQ-Marlin uses its own apply() method when LoRA is not enabled. "
"Modular kernels are only used for LoRA support."
)
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
BatchedMarlinExperts
,
MarlinExperts
,
)
# Ensure quant config is initialized
assert
self
.
moe_quant_config
is
not
None
,
(
"moe_quant_config must be initialized before select_gemm_impl"
)
w13_g_idx
=
getattr
(
layer
,
"w13_g_idx"
,
None
)
w2_g_idx
=
getattr
(
layer
,
"w2_g_idx"
,
None
)
w13_g_idx_sort_indices
=
getattr
(
layer
,
"w13_g_idx_sort_indices"
,
None
)
w2_g_idx_sort_indices
=
getattr
(
layer
,
"w2_g_idx_sort_indices"
,
None
)
# Check if using batched expert format (for Expert Parallelism)
if
(
prepare_finalize
.
activation_format
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
):
# For batched format, use BatchedMarlinExperts
max_num_tokens_per_rank
=
prepare_finalize
.
max_num_tokens_per_rank
()
assert
max_num_tokens_per_rank
is
not
None
return
BatchedMarlinExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
w13_g_idx_sort_indices
=
w13_g_idx_sort_indices
,
w2_g_idx_sort_indices
=
w2_g_idx_sort_indices
,
is_k_full
=
self
.
is_k_full
,
)
else
:
# Standard Marlin experts for AWQ
return
MarlinExperts
(
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
w13_g_idx_sort_indices
=
w13_g_idx_sort_indices
,
w2_g_idx_sort_indices
=
w2_g_idx_sort_indices
,
is_k_full
=
self
.
is_k_full
,
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
a3f8d5dd
...
...
@@ -1266,9 +1266,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
parallel_config
=
getattr
(
getattr
(
layer
,
"vllm_config"
,
None
),
"parallel_config"
,
None
),
)
else
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
View file @
a3f8d5dd
...
...
@@ -28,7 +28,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# dont restrict as emulations
# don
'
t restrict as emulations
return
80
def
create_weights
(
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
a3f8d5dd
...
...
@@ -137,7 +137,7 @@ def get_fp8_moe_backend(
if
(
current_platform
.
is_cuda
()
and
(
current_platform
.
is_device_capability
(
100
)
current_platform
.
is_device_capability
_family
(
100
)
or
current_platform
.
is_device_capability
(
90
)
)
and
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
...
...
@@ -148,7 +148,7 @@ def get_fp8_moe_backend(
logger
.
info_once
(
"Using FlashInfer FP8 MoE TRTLLM backend for SM100"
)
return
Fp8MoeBackend
.
FLASHINFER_TRTLLM
else
:
if
block_quant
and
current_platform
.
is_device_capability
(
100
):
if
block_quant
and
current_platform
.
is_device_capability
_family
(
100
):
raise
ValueError
(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use "
...
...
@@ -193,7 +193,7 @@ def get_fp8_moe_backend(
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
if
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
100
)
and
current_platform
.
is_device_capability
_family
(
100
)
and
block_quant
):
logger
.
info_once
(
...
...
@@ -332,7 +332,10 @@ class Fp8Config(QuantizationConfig):
fused_mapping
=
self
.
packed_modules_mapping
,
):
return
UnquantizedFusedMoEMethod
(
layer
.
moe_config
)
if
self
.
is_checkpoint_fp8_serialized
:
moe_quant_method
=
Fp8MoEMethod
(
self
,
layer
)
else
:
moe_quant_method
=
Fp8OnlineMoEMethod
(
self
,
layer
)
moe_quant_method
.
marlin_input_dtype
=
get_marlin_input_dtype
(
prefix
)
return
moe_quant_method
elif
isinstance
(
layer
,
Attention
):
...
...
@@ -745,8 +748,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
assert
self
.
quant_config
.
is_checkpoint_fp8_serialized
params_dtype
=
torch
.
float8_e4m3fn
if
self
.
block_quant
:
assert
self
.
weight_block_size
is
not
None
layer
.
weight_block_size
=
self
.
weight_block_size
...
...
@@ -773,41 +777,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
f
"weight quantization block_k =
{
block_k
}
."
)
# if we are doing online quantization, patch the weight
# loaded to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
weight_loader
=
extra_weight_attrs
[
"weight_loader"
]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs
=
extra_weight_attrs
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# load the current weight chunk
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
# add a counter to track how many elements we have updated
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
layer
.
_loaded_numel
+=
loaded_weight
.
numel
()
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel
=
layer
.
w13_weight
.
numel
()
+
layer
.
w2_weight
.
numel
()
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
# Delete the bookkeeping
del
layer
.
_loaded_numel
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer
.
_already_called_process_weights_after_loading
=
True
return
res
new_extra_weight_attrs
[
"weight_loader"
]
=
patched_weight_loader
extra_weight_attrs
=
new_extra_weight_attrs
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
...
...
@@ -875,21 +844,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
block_quant
else
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
raise
ValueError
(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
w13_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
...
...
@@ -986,45 +945,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w2_weight_scale_inv
=
Parameter
(
dg_w2_weight_scale_inv
,
requires_grad
=
False
)
# If checkpoint is fp16, quantize in place.
elif
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
fp8_dtype
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
fp8_dtype
)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
replace_parameter
(
layer
,
"w13_weight_scale"
,
torch
.
ones
(
layer
.
local_num_experts
,
dtype
=
torch
.
float32
,
device
=
w13_weight
.
device
,
),
)
for
expert
in
range
(
layer
.
local_num_experts
):
w13_weight
[
expert
,
:,
:],
layer
.
w13_weight_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w13_weight
.
data
[
expert
,
:,
:])
)
w2_weight
[
expert
,
:,
:],
layer
.
w2_weight_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:])
)
replace_parameter
(
layer
,
"w13_weight"
,
w13_weight
)
replace_parameter
(
layer
,
"w2_weight"
,
w2_weight
)
if
self
.
rocm_aiter_moe_enabled
:
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
rocm_aiter_ops
.
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
replace_parameter
(
layer
,
"w13_weight"
,
shuffled_w13
)
replace_parameter
(
layer
,
"w2_weight"
,
shuffled_w2
)
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else
:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
...
...
@@ -1387,6 +1307,151 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return
result
class
Fp8OnlineMoEMethod
(
Fp8MoEMethod
):
"""MoE method for online FP8 quantization.
Supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def
__init__
(
self
,
quant_config
:
Fp8Config
,
layer
:
torch
.
nn
.
Module
):
super
().
__init__
(
quant_config
,
layer
)
assert
not
quant_config
.
is_checkpoint_fp8_serialized
assert
quant_config
.
activation_scheme
==
"dynamic"
assert
quant_config
.
weight_block_size
is
None
assert
self
.
flashinfer_moe_backend
is
None
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
# We are doing online quantization, patch the weight loaded
# to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded.
weight_loader
=
extra_weight_attrs
[
"weight_loader"
]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs
=
extra_weight_attrs
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# load the current weight chunk
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
# add a counter to track how many elements we have updated
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
layer
.
_loaded_numel
+=
loaded_weight
.
numel
()
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel
=
layer
.
w13_weight
.
numel
()
+
layer
.
w2_weight
.
numel
()
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
# Delete the bookkeeping
del
layer
.
_loaded_numel
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer
.
_already_called_process_weights_after_loading
=
True
return
res
new_extra_weight_attrs
[
"weight_loader"
]
=
patched_weight_loader
extra_weight_attrs
=
new_extra_weight_attrs
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
self
.
rocm_aiter_moe_enabled
=
False
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
# Lazy import to avoid importing triton too early.
self
.
rocm_aiter_moe_enabled
=
rocm_aiter_ops
.
is_fused_moe_enabled
()
# If checkpoint is fp16, quantize in place.
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
fp8_dtype
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
fp8_dtype
)
for
expert
in
range
(
layer
.
local_num_experts
):
w13_weight
[
expert
,
:,
:],
layer
.
w13_weight_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w13_weight
.
data
[
expert
,
:,
:])
)
w2_weight
[
expert
,
:,
:],
layer
.
w2_weight_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:])
)
replace_parameter
(
layer
,
"w13_weight"
,
w13_weight
)
replace_parameter
(
layer
,
"w2_weight"
,
w2_weight
)
# Reshuffle weights for AITER if needed.
if
self
.
rocm_aiter_moe_enabled
:
shuffled_w13
,
shuffled_w2
=
rocm_aiter_ops
.
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
replace_parameter
(
layer
,
"w13_weight"
,
shuffled_w13
)
replace_parameter
(
layer
,
"w2_weight"
,
shuffled_w2
)
# Rushuffle weights for MARLIN if needed.
if
self
.
use_marlin
:
prepare_moe_fp8_layer_for_marlin
(
layer
,
False
,
input_dtype
=
self
.
marlin_input_dtype
)
class
Fp8KVCacheMethod
(
BaseKVCacheMethod
):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
a3f8d5dd
...
...
@@ -33,6 +33,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from
vllm.model_executor.models.utils
import
WeightsMapper
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
...
...
@@ -52,6 +53,11 @@ class GGUFConfig(QuantizationConfig):
return
"gguf"
def
get_supported_act_dtypes
(
self
)
->
list
[
torch
.
dtype
]:
# GGUF dequantization kernels use half precision (fp16) internally.
# bfloat16 has precision issues on Blackwell devices.
if
current_platform
.
has_device_capability
(
100
):
logger
.
warning_once
(
"GGUF has precision issues with bfloat16 on Blackwell."
)
return
[
torch
.
half
,
torch
.
float32
]
return
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float32
]
@
classmethod
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
a3f8d5dd
...
...
@@ -732,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
is_a_8bit
=
is_a_8bit
,
)
replace_parameter
(
layer
,
"w2_qweight"
,
marlin_w2_qweight
)
# The modular kernel expects w13_weight and w2_weight,
# but GPTQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer
.
w13_weight
=
layer
.
w13_qweight
# Alias for modular kernel
layer
.
w2_weight
=
layer
.
w2_qweight
# Repack scales
marlin_w13_scales
=
marlin_moe_permute_scales
(
s
=
layer
.
w13_scales
,
...
...
@@ -782,7 +790,107 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
return
None
from
vllm.model_executor.layers.fused_moe.config
import
(
gptq_marlin_moe_quant_config
,
)
return
gptq_marlin_moe_quant_config
(
w1_scale
=
layer
.
w13_scales
,
w2_scale
=
layer
.
w2_scales
,
weight_bits
=
self
.
quant_config
.
weight_bits
,
group_size
=
self
.
quant_config
.
group_size
,
w1_zp
=
getattr
(
layer
,
"w13_qzeros"
,
None
)
if
not
self
.
quant_config
.
is_sym
else
None
,
w2_zp
=
getattr
(
layer
,
"w2_qzeros"
,
None
)
if
not
self
.
quant_config
.
is_sym
else
None
,
w1_bias
=
getattr
(
layer
,
"w13_bias"
,
None
),
w2_bias
=
getattr
(
layer
,
"w2_bias"
,
None
),
)
def
select_gemm_impl
(
self
,
prepare_finalize
,
layer
:
torch
.
nn
.
Module
,
):
"""
Select the GEMM implementation for GPTQ-Marlin MoE.
Returns MarlinExperts configured for GPTQ quantization.
This is ONLY used when LoRA is enabled.
Without LoRA, GPTQ uses its own apply() method.
"""
# Only use modular kernels when LoRA is enabled
# Without LoRA, GPTQ's own apply() method works fine and is more efficient
if
not
self
.
moe
.
is_lora_enabled
:
raise
NotImplementedError
(
"GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
"Modular kernels are only used for LoRA support."
)
# The modular marlin kernels do not support 8-bit weights.
if
self
.
quant_config
.
weight_bits
==
8
:
raise
NotImplementedError
(
"GPTQ-Marlin kernel does not support 8-bit weights."
)
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
BatchedMarlinExperts
,
MarlinExperts
,
)
# Ensure quant config is initialized
assert
self
.
moe_quant_config
is
not
None
,
(
"moe_quant_config must be initialized before select_gemm_impl"
)
w13_g_idx
=
(
getattr
(
layer
,
"w13_g_idx"
,
None
)
if
self
.
quant_config
.
desc_act
else
None
)
w2_g_idx
=
(
getattr
(
layer
,
"w2_g_idx"
,
None
)
if
self
.
quant_config
.
desc_act
else
None
)
w13_g_idx_sort_indices
=
(
getattr
(
layer
,
"w13_g_idx_sort_indices"
,
None
)
if
self
.
quant_config
.
desc_act
else
None
)
w2_g_idx_sort_indices
=
(
getattr
(
layer
,
"w2_g_idx_sort_indices"
,
None
)
if
self
.
quant_config
.
desc_act
else
None
)
# Check if using batched expert format (for Expert Parallelism)
if
(
prepare_finalize
.
activation_format
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
):
# For batched format, use BatchedMarlinExperts
max_num_tokens_per_rank
=
prepare_finalize
.
max_num_tokens_per_rank
()
assert
max_num_tokens_per_rank
is
not
None
return
BatchedMarlinExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
w13_g_idx_sort_indices
=
w13_g_idx_sort_indices
,
w2_g_idx_sort_indices
=
w2_g_idx_sort_indices
,
is_k_full
=
self
.
is_k_full
,
)
else
:
# Standard Marlin experts for GPTQ
return
MarlinExperts
(
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
w13_g_idx_sort_indices
=
w13_g_idx_sort_indices
,
w2_g_idx_sort_indices
=
w2_g_idx_sort_indices
,
is_k_full
=
self
.
is_k_full
,
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
View file @
a3f8d5dd
...
...
@@ -17,7 +17,9 @@ class ScaledMMLinearLayerConfig:
class
ScaledMMLinearKernel
(
ABC
):
@
classmethod
@
abstractmethod
def
get_min_capability
(
cls
)
->
int
:
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
raise
NotImplementedError
@
classmethod
...
...
@@ -35,6 +37,7 @@ class ScaledMMLinearKernel(ABC):
azp_adj_param_name
:
str
,
)
->
None
:
assert
self
.
can_implement
(
c
)
assert
self
.
is_supported
()
self
.
config
=
c
self
.
w_q_name
=
w_q_param_name
self
.
w_s_name
=
w_s_param_name
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
View file @
a3f8d5dd
...
...
@@ -27,7 +27,7 @@ from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS
:
dict
[
PlatformEnum
,
list
[
type
[
ScaledMMLinearKernel
]]]
=
{
PlatformEnum
.
CPU
:
[
CPUScaledMMLinearKernel
],
PlatformEnum
.
CUDA
:
[
CutlassScaledMMLinearKernel
],
PlatformEnum
.
CUDA
:
[
CutlassScaledMMLinearKernel
,
TritonScaledMMLinearKernel
],
PlatformEnum
.
ROCM
:
[
AiterScaledMMLinearKernel
,
TritonScaledMMLinearKernel
],
PlatformEnum
.
TPU
:
[
XLAScaledMMLinearKernel
],
}
...
...
@@ -55,41 +55,25 @@ def choose_scaled_mm_linear_kernel(
type[ScaledMMLinearKernel]: Chosen kernel.
"""
if
compute_capability
is
None
:
_cc
=
current_platform
.
get_device_capability
()
if
_cc
is
not
None
:
compute_capability
=
_cc
[
0
]
*
10
+
_cc
[
1
]
failure_reasons
=
[]
for
kernel
in
_POSSIBLE_KERNELS
[
current_platform
.
_enum
]:
if
kernel
.
__name__
in
os
.
environ
.
get
(
"VLLM_DISABLED_KERNELS"
,
""
).
split
(
","
):
failure_reasons
.
append
(
f
"
{
kernel
.
__name__
}
disabled by environment variable"
)
failure_reasons
.
append
(
f
"
{
kernel
.
__name__
}
: disabled by env var"
)
continue
# If the current platform uses compute_capability,
# make sure the kernel supports the compute cability.
if
compute_capability
is
not
None
:
kernel_min_capability
=
kernel
.
get_min_capability
()
if
(
kernel_min_capability
is
not
None
and
kernel_min_capability
>
compute_capability
):
failure_reasons
.
append
(
f
"
{
kernel
.
__name__
}
requires capability "
f
"
{
kernel_min_capability
}
, current compute capability "
f
"is
{
compute_capability
}
"
)
# make sure the kernel supports the compute capability.
is_supported
,
reason
=
kernel
.
is_supported
(
compute_capability
)
if
not
is_supported
:
failure_reasons
.
append
(
f
"
{
kernel
.
__name__
}
:
{
reason
}
"
)
continue
can_implement
,
reason
=
kernel
.
can_implement
(
config
)
if
not
can_implement
:
failure_reasons
.
append
(
f
"
{
kernel
.
__name__
}
:
{
reason
}
"
)
continue
can_implement
,
failure_reason
=
kernel
.
can_implement
(
config
)
if
can_implement
:
return
kernel
else
:
failure_reasons
.
append
(
f
"
{
kernel
.
__name__
}
cannot implement due to:
{
failure_reason
}
"
)
raise
ValueError
(
"Failed to find a kernel that can implement the "
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
View file @
a3f8d5dd
...
...
@@ -14,17 +14,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
class
AiterScaledMMLinearKernel
(
CutlassScaledMMLinearKernel
):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
90
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_rocm
():
return
(
False
,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+
"currently supported on non-ROCm platform."
,
)
if
compute_capability
is
None
:
_cc
=
current_platform
.
get_device_capability
()
if
_cc
is
not
None
:
compute_capability
=
_cc
.
major
*
10
+
_cc
.
minor
if
compute_capability
is
not
None
and
compute_capability
<
90
:
return
False
,
f
"requires capability 90, got
{
compute_capability
}
"
try
:
import
aiter
# noqa: F401 # deliberately attempt to import aiter
...
...
@@ -34,8 +38,8 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
"AiterScaledMMLinearKernel requires `aiter` which is not "
+
"installed on ROCm."
,
)
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if
not
(
rocm_aiter_ops
.
is_linear_enabled
()
)
:
if
not
rocm_aiter_ops
.
is_linear_enabled
():
return
(
False
,
"AiterScaledMMLinearKernel is disabled. "
...
...
@@ -44,6 +48,10 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+
"`VLLM_ROCM_USE_AITER_LINEAR` default is True."
,
)
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
if
not
c
.
input_symmetric
:
return
(
False
,
...
...
Prev
1
…
10
11
12
13
14
15
16
17
18
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment