Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
425eb81e
Commit
425eb81e
authored
Feb 24, 2026
by
jujl1
Browse files
Merge branch 'v0.15.1-dev' into 'v0.15.1-dev-w4a8+pp_balance'
# Conflicts: # vllm/envs.py
parents
7b2122d9
358bc2c5
Changes
36
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1164 additions
and
127 deletions
+1164
-127
csrc/sampler.cu
csrc/sampler.cu
+18
-5
docs/models/supported_models.md
docs/models/supported_models.md
+10
-8
tests/models/registry.py
tests/models/registry.py
+21
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+62
-0
vllm/config/model.py
vllm/config/model.py
+6
-4
vllm/config/speculative.py
vllm/config/speculative.py
+11
-0
vllm/envs.py
vllm/envs.py
+37
-1
vllm/model_executor/layers/fused_moe/fuse_moe_w16a16_marlin.py
...model_executor/layers/fused_moe/fuse_moe_w16a16_marlin.py
+5
-13
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+110
-9
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+87
-3
vllm/model_executor/layers/fused_moe/router_capture.py
vllm/model_executor/layers/fused_moe/router_capture.py
+360
-0
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+101
-2
vllm/model_executor/layers/mamba/abstract.py
vllm/model_executor/layers/mamba/abstract.py
+2
-1
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+50
-10
vllm/model_executor/layers/mamba/mamba_utils.py
vllm/model_executor/layers/mamba/mamba_utils.py
+6
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+16
-10
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+204
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+9
-4
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+15
-9
vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
...executor/layers/quantization/kernels/scaled_mm/pytorch.py
+34
-44
No files found.
csrc/sampler.cu
View file @
425eb81e
...
...
@@ -215,7 +215,11 @@ __device__ bool processHistogramStep(
// Compute the prefix sum.
int
prefixSum
{
0
},
totalSum
{
0
};
#ifndef USE_ROCM
using
Scan
=
cub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#else:
using
Scan
=
hipcub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#endif
Scan
(
smemFinal
.
histo
.
scan
).
ExclusiveSum
(
binCount
,
prefixSum
,
totalSum
);
// Update the histogram with the prefix sums.
...
...
@@ -334,13 +338,22 @@ static __device__ void topKPerRowJob(const int* indices, const float* logits,
static
constexpr
int
kNumFinalItemsPerThread
=
kNumFinalItems
/
kNumThreadsPerBlock
;
// The class to sort the elements during the final pass.
#ifndef USE_ROCM
using
FinalSort
=
cub
::
BlockRadixSort
<
float
,
kNumThreadsPerBlock
,
kNumFinalItemsPerThread
,
int
>
;
#else
using
FinalSort
=
hipcub
::
BlockRadixSort
<
float
,
kNumThreadsPerBlock
,
kNumFinalItemsPerThread
,
int
>
;
#endif
using
FinalSortTempStorage
=
std
::
conditional_t
<
useRadixSort
,
typename
FinalSort
::
TempStorage
,
int
>
;
// The class to compute the inclusive prefix-sum over the histogram.
#ifndef USE_ROCM
using
Scan
=
cub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#else
using
Scan
=
hipcub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#endif
// The structure to store the final items (for the final pass).
struct
FinalItems
{
// Shared memory to store the indices for the final pass.
...
...
docs/models/supported_models.md
View file @
425eb81e
...
...
@@ -717,6 +717,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
`Qwen2VLForConditionalGeneration`
| QVQ, Qwen2-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/QVQ-72B-Preview`
,
`Qwen/Qwen2-VL-7B-Instruct`
,
`Qwen/Qwen2-VL-72B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen2_5_VLForConditionalGeneration`
| Qwen2.5-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen2.5-VL-3B-Instruct`
,
`Qwen/Qwen2.5-VL-72B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen2_5OmniThinkerForConditionalGeneration`
| Qwen2.5-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen2.5-Omni-3B`
,
`Qwen/Qwen2.5-Omni-7B`
| ✅︎ | ✅︎ |
|
`Qwen3_5ForConditionalGeneration`
| Qwen3.5 | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3.5-9B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3_5MoeForConditionalGeneration`
| Qwen3.5-MOE | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3.5-35B-A3B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3VLForConditionalGeneration`
| Qwen3-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-4B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3VLMoeForConditionalGeneration`
| Qwen3-VL-MOE | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-30B-A3B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3OmniMoeThinkerForConditionalGeneration`
| Qwen3-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen3-Omni-30B-A3B-Instruct`
,
`Qwen/Qwen3-Omni-30B-A3B-Thinking`
| ✅︎ | ✅︎ |
...
...
tests/models/registry.py
View file @
425eb81e
...
...
@@ -943,6 +943,26 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_model_len
=
4096
,
min_transformers_version
=
"4.57"
,
),
"Qwen3_5ForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-9B-Instruct"
),
max_model_len
=
4096
,
min_transformers_version
=
"5.1.0"
,
),
"Qwen3_5MoeForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-35B-A3B-Instruct"
),
max_model_len
=
4096
,
min_transformers_version
=
"5.1.0"
,
),
"Qwen3_5MTP"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-9B-Instruct"
),
speculative_model
=
"Qwen/Qwen3.5-9B-Instruct"
,
min_transformers_version
=
"5.1.0"
,
),
"Qwen3_5MoeMTP"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-35B-A3B-Instruct"
),
speculative_model
=
"Qwen/Qwen3.5-35B-A3B-Instruct"
,
min_transformers_version
=
"5.1.0"
,
),
"Qwen3OmniMoeForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-Omni-30B-A3B-Instruct"
),
max_model_len
=
4096
,
...
...
vllm/_custom_ops.py
View file @
425eb81e
...
...
@@ -19,6 +19,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_tools
from
lmslim.layers.gemm.fp8_utils
import
per_token_quant_fp8
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
try
:
...
...
@@ -1878,6 +1879,67 @@ def scaled_fp4_experts_quant(
output_scales
=
output_scales
.
view
(
torch
.
float8_e4m3fn
)
return
output
,
output_scales
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
num_token_padding
:
Optional
[
int
]
=
None
,
scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
use_per_token_if_dynamic
:
bool
=
False
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
group_shape
:
Optional
[
tuple
[
int
,
int
]]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert
(
input
.
ndim
==
2
)
shape
:
Union
[
tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype
:
torch
.
dtype
=
current_platform
.
fp8_dtype
()
if
num_token_padding
:
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
if
output
is
None
:
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
else
:
assert
num_token_padding
is
None
,
\
"padding not supported if output passed in"
assert
output
.
dtype
==
out_dtype
if
scale
is
None
:
if
use_per_token_if_dynamic
:
scale
=
torch
.
empty
((
shape
[
0
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
output
,
scale
=
per_token_quant_fp8
(
input
.
contiguous
())
else
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
assert
scale
.
numel
()
==
1
,
f
"
{
scale
.
shape
}
"
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
return
output
,
scale
def
silu_and_mul_scaled_fp4_experts_quant
(
input_tensor
:
torch
.
Tensor
,
...
...
vllm/config/model.py
View file @
425eb81e
...
...
@@ -1250,7 +1250,9 @@ class ModelConfig:
return
sum
(
t
==
1
for
t
in
attn_type_list
[
start
:
end
])
# Hybrid model Qwen3Next
layer_types_value
=
getattr
(
self
.
hf_config
,
"layer_types"
,
None
)
# layer_types_value = getattr(self.hf_config, "layer_types", None)
# Hybrid model Qwen3Next Qwen3.5 Series
layer_types_value
=
getattr
(
self
.
hf_text_config
,
"layer_types"
,
None
)
if
layer_types_value
is
not
None
:
if
block_type
==
"attention"
:
return
sum
(
...
...
vllm/config/speculative.py
View file @
425eb81e
...
...
@@ -37,6 +37,7 @@ MTPModelTypes = Literal[
"ernie_mtp"
,
"exaone_moe_mtp"
,
"qwen3_next_mtp"
,
"qwen3_5_mtp"
,
"longcat_flash_mtp"
,
"mtp"
,
"pangu_ultra_moe_mtp"
,
...
...
@@ -246,6 +247,16 @@ class SpeculativeConfig:
{
"n_predict"
:
n_predict
,
"architectures"
:
[
"ExaoneMoeMTP"
]}
)
if
hf_config
.
model_type
in
(
"qwen3_5"
,
"qwen3_5_moe"
):
is_moe
=
hf_config
.
model_type
==
"qwen3_5_moe"
hf_config
.
model_type
=
"qwen3_5_mtp"
n_predict
=
getattr
(
hf_config
,
"mtp_num_hidden_layers"
,
None
)
hf_config
.
update
(
{
"n_predict"
:
n_predict
,
"architectures"
:
[
"Qwen3_5MoeMTP"
if
is_moe
else
"Qwen3_5MTP"
],
}
)
if
hf_config
.
model_type
==
"longcat_flash"
:
hf_config
.
model_type
=
"longcat_flash_mtp"
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
1
)
...
...
vllm/envs.py
View file @
425eb81e
...
...
@@ -292,7 +292,15 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_W8A8_BACKEND
:
int
=
3
VLLM_USE_PP_BALANCE
=
True
VLLM_MOE_ROUTER_CAPTURE
:
bool
=
False
VLLM_MOE_ROUTER_CAPTURE_DIR
:
str
=
"/tmp"
VLLM_MOE_ROUTER_CAPTURE_RANK
:
int
=
-
1
VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS
:
int
=
0
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT
:
int
=
-
1
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT
:
int
=
-
1
VLLM_REJECT_SAMPLE_OPT
:
bool
=
False
VLLM_USE_MOE_W16A16_TRITON
:
bool
=
False
VLLM_V1_FAST_TOKEN_ID_COPY
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -1841,14 +1849,42 @@ environment_variables: dict[str, Callable[[], Any]] = {
# blaslt: 3 (default)
# rocblas: others
"VLLM_W8A8_BACKEND"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_W8A8_BACKEND"
,
"3"
)),
# Capture MoE router logits for debugging/analysis.
"VLLM_MOE_ROUTER_CAPTURE"
:
lambda
:
(
os
.
getenv
(
"VLLM_MOE_ROUTER_CAPTURE"
,
"0"
).
lower
()
in
(
"true"
,
"1"
)),
# Output directory for MoE router capture dumps.
"VLLM_MOE_ROUTER_CAPTURE_DIR"
:
lambda
:
os
.
environ
.
get
(
"VLLM_MOE_ROUTER_CAPTURE_DIR"
,
"/tmp"
,
),
# Capture only the specified rank; set to -1 to capture all ranks.
"VLLM_MOE_ROUTER_CAPTURE_RANK"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_MOE_ROUTER_CAPTURE_RANK"
,
"-1"
)),
# Max number of MoE layers to record per process (0 = unlimited).
"VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS"
,
"0"
)),
# Only capture when num_tokens > N (negative disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT"
,
"-1"
)),
# Only capture when num_tokens < N (0 disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT"
,
"-1"
)),
# vllm will use optimized reject sample
"VLLM_REJECT_SAMPLE_OPT"
:
lambda
:
(
os
.
getenv
(
'VLLM_REJECT_SAMPLE_OPT'
,
'True'
).
lower
()
in
(
"true"
,
"1"
)),
# Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MOE_W16A16_TRITON"
,
"0"
).
lower
()
in
(
"true"
,
"1"
)),
#If set to 1/True, enable the V1 fast token-id copy path in InputBatch.
"VLLM_V1_FAST_TOKEN_ID_COPY"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_V1_FAST_TOKEN_ID_COPY"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/fused_moe/fuse_moe_w16a16_marlin.py
View file @
425eb81e
...
...
@@ -397,19 +397,11 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
)
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
top_k_num
,
K
)
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
from
lightop
import
op
as
op
op
.
moe_sum
(
input
=
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
output
=
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
bias
=
shared_output
[
begin_chunk_idx
:
end_chunk_idx
],
expert_mask
=
None
,
num_local_tokens
=
None
,
factor
=
routed_scaling_factor
)
else
:
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM
:
from
lightop
import
op
as
op
op
.
moe_sum
(
input
=
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
output
=
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
bias
=
None
,
expert_mask
=
None
,
num_local_tokens
=
None
,
factor
=
1.0
)
elif
envs
.
VLLM_USE_OPT_MOE_SUM
:
moe_reduce_dispatch
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
begin_chunk_idx
,
end_chunk_idx
)
else
:
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
425eb81e
...
...
@@ -1612,6 +1612,7 @@ def fused_experts(
expert_map
:
torch
.
Tensor
|
None
=
None
,
quant_config
:
FusedMoEQuantConfig
|
None
=
None
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
:
if
quant_config
is
None
:
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
...
...
@@ -1705,6 +1706,112 @@ def fused_experts_impl(
)
->
torch
.
Tensor
:
# Check constraints.
num_tokens
=
hidden_states
.
size
(
0
)
top_k_num
=
topk_ids
.
size
(
1
)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
# Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout.
if
not
use_nn_moe
:
K
=
hidden_states
.
size
(
1
)
def
_is_marlin_w16a16_packed
(
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
)
->
bool
:
if
w1
.
dim
()
!=
3
or
w2
.
dim
()
!=
3
:
return
False
if
w1
.
size
(
0
)
!=
w2
.
size
(
0
):
return
False
k_div16
=
w1
.
size
(
1
)
if
k_div16
*
16
!=
K
:
return
False
if
w1
.
size
(
2
)
%
16
!=
0
:
return
False
twoN
=
w1
.
size
(
2
)
//
16
if
twoN
%
2
!=
0
:
return
False
N
=
twoN
//
2
if
w2
.
size
(
2
)
!=
K
*
16
:
return
False
if
w2
.
size
(
1
)
*
16
!=
N
:
return
False
return
True
is_packed
=
(
getattr
(
w1
,
"marlin_w16a16_packed"
,
False
)
or
getattr
(
w2
,
"marlin_w16a16_packed"
,
False
)
or
_is_marlin_w16a16_packed
(
w1
,
w2
)
)
if
is_packed
:
if
envs
.
VLLM_USE_MOE_W16A16_TRITON
:
raise
RuntimeError
(
"VLLM_USE_MOE_W16A16_TRITON=1 forces Triton MoE, but the MoE weights are "
"packed in Marlin W16A16 layout. Please load unpacked weights or set "
"VLLM_USE_MOE_W16A16_TRITON=0."
)
try
:
from
vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin
import
(
fused_experts_impl_w16a16_marlin
,
)
except
Exception
:
fused_experts_impl_w16a16_marlin
=
None
# type: ignore
if
fused_experts_impl_w16a16_marlin
is
None
:
raise
RuntimeError
(
"Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
"Ensure lightop/lmslim is installed and LMSLIM_USE_LIGHTOP=1."
)
if
activation
!=
"silu"
:
raise
RuntimeError
(
"Marlin W16A16 MoE only supports activation='silu'."
)
if
apply_router_weight_on_input
:
raise
RuntimeError
(
"Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
)
if
w1_bias
is
not
None
or
w2_bias
is
not
None
:
raise
RuntimeError
(
"Marlin W16A16 MoE does not support expert biases."
)
E
=
w1
.
size
(
0
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
twoN
=
w1
.
size
(
2
)
//
16
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
cache13
=
get_moe_cache
(
top_k_num
,
twoN
,
K
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
else
:
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
twoN
,
K
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
return
fused_experts_impl_w16a16_marlin
(
hidden_states
=
hidden_states
,
w1_marlin
=
w1
,
w2_marlin
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
cache13
=
cache13
,
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
use_nn_moe
=
False
,
)
if
use_nn_moe
:
E
,
_
,
N
=
w1
.
size
()
else
:
...
...
@@ -1713,18 +1820,12 @@ def fused_experts_impl(
if
global_num_experts
==
-
1
:
global_num_experts
=
E
top_k_num
=
topk_ids
.
size
(
1
)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
cache13
=
get_moe_cache
(
top_k_num
,
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
],
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
else
:
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
if
use_int8_w8a8
is
True
:
if
use_int8_w8a8
or
use_fp8_w8a8
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
...
...
@@ -1734,8 +1835,8 @@ def fused_experts_impl(
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
True
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_channel_quant
=
per_channel_quant
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
425eb81e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
os
from
collections.abc
import
Callable
,
Iterable
from
contextlib
import
nullcontext
...
...
@@ -72,6 +73,64 @@ from vllm.model_executor.layers.fused_moe.fused_moe import is_power_of_two
logger
=
init_logger
(
__name__
)
_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES
:
tuple
[
int
,
...]
=
(
1
,
128
)
@
functools
.
lru_cache
def
_is_marlin_w16a16_moe_supported
(
E
:
int
,
N
:
int
,
K
:
int
,
top_k
:
int
,
dtype
:
torch
.
dtype
,
)
->
bool
:
"""Return True if lightop reports Marlin W16A16 MoE is supported.
This is a best-effort probe used to decide whether we can safely pre-pack
weights into Marlin layout (which would otherwise prevent fallback).
"""
if
not
(
current_platform
.
is_cuda_alike
()
and
torch
.
cuda
.
is_available
()):
return
False
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
return
False
if
K
%
32
!=
0
or
N
%
16
!=
0
:
return
False
if
E
<=
0
or
N
<=
0
or
K
<=
0
or
top_k
<=
0
:
return
False
try
:
from
lightop
import
get_moe_cuda_marlin_config_w16a16
props
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
())
arch_name
=
getattr
(
props
,
"gcnArchName"
,
None
)
if
isinstance
(
arch_name
,
str
)
and
arch_name
:
arch_name
=
arch_name
.
split
(
":"
)[
0
]
else
:
arch_name
=
getattr
(
props
,
"name"
,
None
)
if
not
isinstance
(
arch_name
,
str
)
or
not
arch_name
:
return
False
arch_cu
=
props
.
multi_processor_count
twoN
=
2
*
N
for
bs
in
_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES
:
_
,
_
,
status
=
get_moe_cuda_marlin_config_w16a16
(
E
,
bs
,
twoN
,
K
,
K
,
N
,
top_k
,
arch_name
,
arch_cu
,
dtype
,
)
if
not
status
:
return
False
return
True
except
Exception
:
return
False
class
FusedMoeWeightScaleSupported
(
Enum
):
TENSOR
=
"tensor"
...
...
@@ -543,9 +602,9 @@ class FusedMoE(CustomOp):
# moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
self
.
use_fused_gate
=
envs
.
VLLM_ENABLE_MOE_FUSED_GATE
\
and
self
.
e_score_correction_bias
is
not
None
\
and
num_expert_group
is
not
None
\
and
self
.
global_num_experts
//
num_expert_group
<=
32
\
and
is_power_of_two
(
e_score_correction_bias
.
shape
[
0
])
and
num_expert_group
is
not
None
#
\
#
and self.global_num_experts // num_expert_group <= 32 \
#
and is_power_of_two(e_score_correction_bias.shape[0])
self
.
router
=
create_fused_moe_router
(
top_k
=
top_k
,
...
...
@@ -632,9 +691,28 @@ class FusedMoE(CustomOp):
if
quant_config
is
None
:
# Not considering quant for now, temporarily
self
.
_marlin_w16a16_moe_enabled
=
(
not
envs
.
VLLM_USE_MOE_W16A16_TRITON
and
params_dtype
==
moe_in_dtype
and
not
self
.
moe_config
.
has_bias
and
self
.
activation
==
"silu"
and
not
self
.
apply_router_weight_on_input
and
_is_marlin_w16a16_moe_supported
(
E
=
self
.
local_num_experts
,
N
=
self
.
intermediate_size_per_partition
,
K
=
self
.
hidden_size
,
top_k
=
self
.
top_k
,
dtype
=
moe_in_dtype
,
)
)
self
.
use_nn_moe
=
int
(
os
.
environ
.
get
(
'MOE_NN'
,
1
))
==
1
# Marlin W16A16 MoE requires the non-NN weight layout.
if
self
.
_marlin_w16a16_moe_enabled
:
self
.
use_nn_moe
=
False
else
:
self
.
use_nn_moe
=
False
self
.
_marlin_w16a16_moe_enabled
=
False
moe_quant_params
=
{
"num_experts"
:
self
.
local_num_experts
,
...
...
@@ -671,6 +749,12 @@ class FusedMoE(CustomOp):
# should be safe to swap out the quant_method.
def
maybe_init_modular_kernel
(
self
)
->
None
:
# If this layer is configured for Marlin W16A16 path, we intentionally
# keep the monolithic execution route so runtime can dispatch to
# fused_experts_impl_w16a16_marlin when weights are packed.
if
getattr
(
self
,
"_marlin_w16a16_moe_enabled"
,
False
):
return
self
.
ensure_moe_quant_config_init
()
# routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend.
...
...
vllm/model_executor/layers/fused_moe/router_capture.py
0 → 100644
View file @
425eb81e
"""
Utilities for capturing MoE router distributions from real workloads.
This is intentionally lightweight and gated behind env vars so it has zero
runtime impact unless explicitly enabled.
Env vars (defaults from vllm.envs):
- VLLM_MOE_ROUTER_CAPTURE=0/1: enable capture (default: 0).
- VLLM_MOE_ROUTER_CAPTURE_DIR=/path: output directory for per-process dumps
(default: /tmp).
- VLLM_MOE_ROUTER_CAPTURE_RANK=N: only capture on the given torch.distributed
rank (default: -1; set to -1 to capture all ranks).
- VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS=N: max number of layers to record per
process (default: 0; 0 = unlimited).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT=A: only record calls where router_logits
has num_tokens > A (default: -1; <0 = disabled).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT=B: only record calls where router_logits
has num_tokens < B (default: -1; 0 = disabled).
Output format:
- A single `.pt` per captured num_tokens (and per rank if torch.distributed is
initialized).
- Payload includes `layers_by_num_tokens: dict[str, dict[layer_name, layer_state]]`.
- A convenience `layers` field is also included (same as
`layers_by_num_tokens[str(num_tokens)]`) for easy loading.
- For each captured MoE layer, stores a list of 2D tensors
`router_logits_chunks: list[Tensor[num_tokens_i, num_experts]]` on CPU,
typically in fp16 for space efficiency.
"""
from
__future__
import
annotations
import
atexit
import
inspect
import
os
import
socket
import
threading
import
time
from
dataclasses
import
dataclass
from
typing
import
Optional
import
torch
import
vllm.envs
as
envs
_DEFAULT_SKIP_STACK_FUNCS
=
(
"profile_run"
,
"_dummy_run"
,
"determine_available_memory"
)
@
dataclass
(
frozen
=
True
)
class
RouterCaptureConfig
:
enabled
:
bool
=
False
out_dir
:
str
=
"/tmp"
skip_profile
:
bool
=
True
skip_stack_funcs
:
tuple
[
str
,
...]
=
_DEFAULT_SKIP_STACK_FUNCS
only_rank
:
Optional
[
int
]
=
0
max_layers
:
int
=
0
num_tokens_gt
:
Optional
[
int
]
=
None
num_tokens_lt
:
Optional
[
int
]
=
None
@
staticmethod
def
from_env
()
->
"RouterCaptureConfig"
:
enabled
=
envs
.
VLLM_MOE_ROUTER_CAPTURE
out_dir
=
envs
.
VLLM_MOE_ROUTER_CAPTURE_DIR
skip_profile
=
True
skip_stack_funcs
=
_DEFAULT_SKIP_STACK_FUNCS
only_rank
:
Optional
[
int
]
=
None
if
envs
.
VLLM_MOE_ROUTER_CAPTURE_RANK
>=
0
:
only_rank
=
envs
.
VLLM_MOE_ROUTER_CAPTURE_RANK
max_layers
=
envs
.
VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS
num_tokens_gt_opt
=
(
envs
.
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT
if
envs
.
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT
>=
0
else
None
)
num_tokens_lt_opt
=
(
envs
.
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT
if
envs
.
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT
>
0
else
None
)
# Per-size mode requires an explicit token-count filter to avoid
# unbounded captures by default.
if
num_tokens_gt_opt
is
None
and
num_tokens_lt_opt
is
None
:
enabled
=
False
if
(
num_tokens_gt_opt
is
not
None
and
num_tokens_lt_opt
is
not
None
and
num_tokens_gt_opt
>=
num_tokens_lt_opt
):
enabled
=
False
return
RouterCaptureConfig
(
enabled
=
enabled
,
out_dir
=
out_dir
,
skip_profile
=
skip_profile
,
skip_stack_funcs
=
skip_stack_funcs
,
only_rank
=
only_rank
,
max_layers
=
max_layers
,
num_tokens_gt
=
num_tokens_gt_opt
,
num_tokens_lt
=
num_tokens_lt_opt
)
def
_in_profile_run
(
skip_stack_funcs
:
tuple
[
str
,
...])
->
bool
:
"""
Best-effort detection for vLLM startup profiling/warmup runs.
Startup warmups often execute MoE kernels with synthetic shapes. When
enabled, skip captures from these stacks so the first capture comes from a
real request.
"""
if
not
skip_stack_funcs
:
return
False
frame
=
inspect
.
currentframe
()
try
:
while
frame
is
not
None
:
name
=
frame
.
f_code
.
co_name
if
name
in
skip_stack_funcs
:
return
True
frame
=
frame
.
f_back
finally
:
# Avoid reference cycles.
del
frame
return
False
class
_RouterCapture
:
def
__init__
(
self
,
cfg
:
RouterCaptureConfig
)
->
None
:
self
.
cfg
=
cfg
# Bucket captures by token count.
self
.
_layers_by_num_tokens
:
dict
[
int
,
dict
[
str
,
dict
[
str
,
object
]]]
=
{}
self
.
_layer_names
:
set
[
str
]
=
set
()
self
.
_completed_num_tokens
:
set
[
int
]
=
set
()
self
.
_lock
=
threading
.
Lock
()
self
.
_flush_counter
=
0
self
.
_pid
=
os
.
getpid
()
self
.
_host
=
socket
.
gethostname
()
self
.
_start_time
=
time
.
time
()
os
.
makedirs
(
cfg
.
out_dir
,
exist_ok
=
True
)
atexit
.
register
(
self
.
flush
)
def
_bucket_for_num_tokens
(
self
,
num_tokens
:
int
)
->
Optional
[
int
]:
"""Return the per-size bucket key for this record call, or None if filtered."""
if
self
.
cfg
.
num_tokens_gt
is
None
and
self
.
cfg
.
num_tokens_lt
is
None
:
return
None
if
self
.
cfg
.
num_tokens_gt
is
not
None
:
if
int
(
num_tokens
)
<=
int
(
self
.
cfg
.
num_tokens_gt
):
return
None
if
self
.
cfg
.
num_tokens_lt
is
not
None
:
if
int
(
num_tokens
)
>=
int
(
self
.
cfg
.
num_tokens_lt
):
return
None
bucket_num_tokens
=
int
(
num_tokens
)
if
bucket_num_tokens
!=
0
and
bucket_num_tokens
in
self
.
_completed_num_tokens
:
return
None
return
bucket_num_tokens
def
_snapshot_layers_by_num_tokens
(
self
,
layers_by_num_tokens
:
dict
[
int
,
dict
[
str
,
dict
[
str
,
object
]]],
)
->
dict
[
int
,
dict
[
str
,
dict
[
str
,
object
]]]:
snapshot
:
dict
[
int
,
dict
[
str
,
dict
[
str
,
object
]]]
=
{}
for
num_tokens
,
bucket
in
layers_by_num_tokens
.
items
():
bucket_snapshot
:
dict
[
str
,
dict
[
str
,
object
]]
=
{}
for
layer_name
,
state
in
bucket
.
items
():
chunks
=
state
.
get
(
"router_logits_chunks"
,
[])
bucket_snapshot
[
layer_name
]
=
{
"num_experts"
:
int
(
state
.
get
(
"num_experts"
,
0
)),
"num_tokens"
:
int
(
state
.
get
(
"num_tokens"
,
0
)),
"router_logits_chunks"
:
list
(
chunks
),
}
snapshot
[
int
(
num_tokens
)]
=
bucket_snapshot
return
snapshot
@
torch
.
no_grad
()
def
record
(
self
,
layer_name
:
str
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
)
->
None
:
if
self
.
cfg
.
skip_profile
and
_in_profile_run
(
self
.
cfg
.
skip_stack_funcs
):
return
if
self
.
cfg
.
only_rank
is
not
None
:
rank
=
_get_rank
()
if
rank
is
not
None
and
rank
!=
self
.
cfg
.
only_rank
:
return
if
router_logits
.
dim
()
!=
2
:
return
num_tokens
,
num_experts
=
router_logits
.
shape
if
num_tokens
==
0
or
num_experts
==
0
:
return
bucket_num_tokens
=
self
.
_bucket_for_num_tokens
(
int
(
num_tokens
))
if
bucket_num_tokens
is
None
:
return
# Limit the number of recorded layers to avoid unbounded dumps.
if
layer_name
not
in
self
.
_layer_names
:
if
self
.
cfg
.
max_layers
!=
0
and
len
(
self
.
_layer_names
)
>=
self
.
cfg
.
max_layers
:
return
self
.
_layer_names
.
add
(
layer_name
)
# Store on CPU to avoid consuming GPU memory during long runs.
# fp16 is typically sufficient because we primarily care about
# distribution and relative ordering (top-k), not exact values.
router_logits_cpu
=
router_logits
.
detach
()
if
router_logits_cpu
.
is_cuda
:
router_logits_cpu
=
router_logits_cpu
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
float16
)
else
:
router_logits_cpu
=
router_logits_cpu
.
to
(
dtype
=
torch
.
float16
)
bucket_snapshot
:
Optional
[
dict
[
str
,
dict
[
str
,
object
]]]
=
None
should_flush
=
False
with
self
.
_lock
:
bucket
=
self
.
_layers_by_num_tokens
.
setdefault
(
bucket_num_tokens
,
{})
if
layer_name
in
bucket
:
return
bucket
[
layer_name
]
=
{
"num_experts"
:
int
(
num_experts
),
"num_tokens"
:
int
(
num_tokens
),
"router_logits_chunks"
:
[
router_logits_cpu
],
}
if
self
.
cfg
.
max_layers
!=
0
and
len
(
bucket
)
>=
int
(
self
.
cfg
.
max_layers
):
should_flush
=
True
bucket_snapshot
=
self
.
_snapshot_layers_by_num_tokens
(
{
int
(
bucket_num_tokens
):
bucket
})[
int
(
bucket_num_tokens
)]
self
.
_completed_num_tokens
.
add
(
int
(
bucket_num_tokens
))
self
.
_layers_by_num_tokens
.
pop
(
int
(
bucket_num_tokens
),
None
)
if
should_flush
and
bucket_snapshot
is
not
None
:
self
.
_flush_payload
(
layers_by_num_tokens
=
{
int
(
bucket_num_tokens
):
bucket_snapshot
},
file_tag
=
f
"nt
{
int
(
bucket_num_tokens
)
}
"
,
)
def
_flush_payload
(
self
,
*
,
layers_by_num_tokens
:
dict
[
int
,
dict
[
str
,
dict
[
str
,
object
]]],
file_tag
:
Optional
[
str
]
=
None
,
)
->
Optional
[
str
]:
if
not
self
.
cfg
.
enabled
:
return
None
if
self
.
cfg
.
only_rank
is
not
None
:
rank
=
_get_rank
()
if
rank
is
not
None
and
rank
!=
self
.
cfg
.
only_rank
:
return
None
rank
=
_get_rank
()
now
=
time
.
time
()
ts
=
time
.
strftime
(
"%Y%m%d_%H%M%S"
,
time
.
localtime
(
now
))
ts_us
=
int
(
now
*
1_000_000
)
with
self
.
_lock
:
flush_idx
=
self
.
_flush_counter
self
.
_flush_counter
+=
1
rank_str
=
f
"rank
{
rank
}
"
if
rank
is
not
None
else
"rankNA"
tag
=
f
"
{
file_tag
}
_"
if
file_tag
else
""
out_path
=
os
.
path
.
join
(
self
.
cfg
.
out_dir
,
f
"moe_router_stats_
{
tag
}{
ts_us
}
_
{
self
.
_host
}
_
{
rank_str
}
_pid
{
self
.
_pid
}
_flush
{
flush_idx
}
.pt"
,
)
layers_by_num_tokens_out
:
dict
[
str
,
object
]
=
{}
for
num_tokens
,
bucket
in
layers_by_num_tokens
.
items
():
bucket_out
:
dict
[
str
,
object
]
=
{}
for
layer_name
,
state
in
bucket
.
items
():
bucket_out
[
layer_name
]
=
{
"num_experts"
:
int
(
state
[
"num_experts"
]),
"num_tokens"
:
int
(
state
[
"num_tokens"
]),
"router_logits_chunks"
:
state
[
"router_logits_chunks"
],
# type: ignore[typeddict-item]
}
layers_by_num_tokens_out
[
str
(
int
(
num_tokens
))]
=
bucket_out
payload
:
dict
[
str
,
object
]
=
{
"meta"
:
{
"timestamp"
:
ts
,
"timestamp_us"
:
ts_us
,
"flush_index"
:
int
(
flush_idx
),
"host"
:
self
.
_host
,
"pid"
:
self
.
_pid
,
"rank"
:
rank
,
"wall_time_s"
:
float
(
now
-
self
.
_start_time
),
},
"layers_by_num_tokens"
:
layers_by_num_tokens_out
,
}
# Backward-compatible convenience field when there is a single bucket.
if
len
(
layers_by_num_tokens
)
==
1
:
(
only_bucket_key
,
)
=
layers_by_num_tokens
.
keys
()
payload
[
"layers"
]
=
layers_by_num_tokens_out
[
str
(
int
(
only_bucket_key
))]
try
:
torch
.
save
(
payload
,
out_path
)
except
Exception
:
return
None
return
out_path
def
flush
(
self
)
->
Optional
[
str
]:
with
self
.
_lock
:
if
not
self
.
_layers_by_num_tokens
:
return
None
snapshot
=
self
.
_snapshot_layers_by_num_tokens
(
self
.
_layers_by_num_tokens
)
return
self
.
_flush_payload
(
layers_by_num_tokens
=
snapshot
)
def
reset
(
self
)
->
None
:
with
self
.
_lock
:
self
.
_layers_by_num_tokens
.
clear
()
self
.
_layer_names
.
clear
()
self
.
_completed_num_tokens
.
clear
()
self
.
_start_time
=
time
.
time
()
_CAPTURE
:
Optional
[
_RouterCapture
]
=
None
_CAPTURE_DISABLED
:
bool
=
False
def
_disable_global_capture
()
->
None
:
global
_CAPTURE
,
_CAPTURE_DISABLED
_CAPTURE
=
None
_CAPTURE_DISABLED
=
True
def
_get_rank
()
->
Optional
[
int
]:
if
torch
.
distributed
.
is_available
()
and
torch
.
distributed
.
is_initialized
():
try
:
return
torch
.
distributed
.
get_rank
()
except
Exception
:
return
None
return
None
def
_get_capture
()
->
Optional
[
_RouterCapture
]:
global
_CAPTURE
,
_CAPTURE_DISABLED
if
_CAPTURE_DISABLED
:
return
None
if
_CAPTURE
is
not
None
:
return
_CAPTURE
cfg
=
RouterCaptureConfig
.
from_env
()
if
not
cfg
.
enabled
:
_disable_global_capture
()
return
None
if
cfg
.
only_rank
is
not
None
:
rank
=
_get_rank
()
if
rank
is
not
None
and
rank
!=
cfg
.
only_rank
:
_disable_global_capture
()
return
None
_CAPTURE
=
_RouterCapture
(
cfg
)
return
_CAPTURE
@
torch
.
no_grad
()
def
maybe_record_router_logits
(
*
,
layer_name
:
str
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
)
->
None
:
capture
=
_get_capture
()
if
capture
is
None
:
return
capture
.
record
(
layer_name
=
layer_name
,
router_logits
=
router_logits
,
top_k
=
top_k
)
def
maybe_flush_router_capture
(
*
,
reset
:
bool
=
False
)
->
Optional
[
str
]:
"""Flush capture buffers to disk without exiting the process."""
capture
=
_get_capture
()
if
capture
is
None
:
return
None
out_path
=
capture
.
flush
()
if
out_path
is
not
None
and
reset
:
capture
.
reset
()
return
out_path
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
View file @
425eb81e
...
...
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
make_unquantized_moe_kernel
,
select_unquantized_moe_backend
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
...
...
@@ -230,6 +231,87 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
super
().
process_weights_after_loading
(
layer
)
# If Marlin W16A16 MoE is supported, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout.
#
# This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state.
if
(
getattr
(
layer
,
"_marlin_w16a16_moe_enabled"
,
False
)
and
current_platform
.
is_cuda_alike
()
and
not
getattr
(
layer
,
"use_nn_moe"
,
False
)
and
not
getattr
(
layer
,
"_marlin_w16a16_moe_packed"
,
False
)
):
w1
=
layer
.
w13_weight
w2
=
layer
.
w2_weight
if
(
w1
.
is_cuda
and
w2
.
is_cuda
and
w1
.
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
)
and
w2
.
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
)
):
try
:
if
w1
.
dim
()
!=
3
or
w2
.
dim
()
!=
3
or
w1
.
size
(
0
)
!=
w2
.
size
(
0
):
raise
RuntimeError
(
"Unexpected MoE weight shapes"
)
twoN
,
K
=
w1
.
size
(
1
),
w1
.
size
(
2
)
if
w2
.
size
(
1
)
!=
K
:
raise
RuntimeError
(
"Unexpected MoE w2 layout"
)
N
=
w2
.
size
(
2
)
if
twoN
!=
2
*
N
:
raise
RuntimeError
(
"Unexpected MoE hidden dims"
)
if
K
%
32
!=
0
or
N
%
16
!=
0
:
raise
RuntimeError
(
"Marlin packing requires alignment"
)
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.fused_moe.marlin_quant
import
(
w16a16_marlin_weight
,
)
def
_pack_per_expert
(
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_experts
=
weight
.
shape
[
0
]
packed0
=
w16a16_marlin_weight
(
weight
[
0
]).
contiguous
()
packed
=
packed0
.
new_empty
((
num_experts
,)
+
packed0
.
shape
)
packed
[
0
].
copy_
(
packed0
)
del
packed0
for
i
in
range
(
1
,
num_experts
):
tmp
=
w16a16_marlin_weight
(
weight
[
i
]).
contiguous
()
packed
[
i
].
copy_
(
tmp
)
del
tmp
return
packed
with
torch
.
no_grad
():
w1_packed
=
_pack_per_expert
(
w1
)
w2_packed
=
_pack_per_expert
(
w2
)
new_w1
=
Parameter
(
w1_packed
,
requires_grad
=
False
)
new_w2
=
Parameter
(
w2_packed
,
requires_grad
=
False
)
# Preserve any custom weight attributes (e.g. loaders).
if
hasattr
(
w1
,
"__dict__"
):
for
k
,
v
in
w1
.
__dict__
.
items
():
setattr
(
new_w1
,
k
,
v
)
if
hasattr
(
w2
,
"__dict__"
):
for
k
,
v
in
w2
.
__dict__
.
items
():
setattr
(
new_w2
,
k
,
v
)
setattr
(
new_w1
,
"marlin_w16a16_packed"
,
True
)
setattr
(
new_w2
,
"marlin_w16a16_packed"
,
True
)
layer
.
w13_weight
=
new_w1
layer
.
w2_weight
=
new_w2
layer
.
_marlin_w16a16_moe_packed
=
True
return
except
Exception
:
# If packing dependencies are unavailable, fall back to the
# standard (non-Marlin) layouts.
pass
# Padding the weight for better performance on ROCm
layer
.
w13_weight
.
data
=
self
.
_maybe_pad_weight
(
layer
.
w13_weight
.
data
)
layer
.
w2_weight
.
data
=
self
.
_maybe_pad_weight
(
layer
.
w2_weight
.
data
)
...
...
@@ -289,7 +371,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
forward
(
layer
=
layer
,
...
...
@@ -297,7 +378,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
use_nn_moe
=
use_nn_moe
,
use_fused_gate
=
use_fused_gate
,
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
:
...
...
@@ -317,6 +397,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
getattr
(
layer
,
"_marlin_w16a16_moe_enabled"
,
False
)
and
getattr
(
layer
,
"_marlin_w16a16_moe_packed"
,
False
)
):
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
self
.
allow_inplace
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
quant_config
=
self
.
get_fused_moe_quant_config
(
layer
),
use_nn_moe
=
use_nn_moe
,
)
assert
self
.
kernel
is
not
None
return
self
.
kernel
(
hidden_states
=
x
,
...
...
vllm/model_executor/layers/mamba/abstract.py
View file @
425eb81e
...
...
@@ -43,7 +43,8 @@ class MambaBase(AttentionLayerBase):
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
|
None
:
if
(
vllm_config
.
speculative_config
is
not
None
and
vllm_config
.
model_config
.
hf_config
.
model_type
not
in
[
"qwen3_next"
]
and
vllm_config
.
model_config
.
hf_config
.
model_type
not
in
[
"qwen3_next"
,
"qwen3_5"
,
"qwen3_5_moe"
]
):
raise
NotImplementedError
(
"Mamba with speculative decoding is not supported yet."
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
425eb81e
...
...
@@ -44,6 +44,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionMetadata
import
vllm.envs
as
envs
# Added by the IBM Team, 2024
...
...
@@ -171,6 +172,15 @@ def mamba_v2_sharded_weight_loader(
# - track boundary of (sharded) param, and loaded_weight, respectively
boundary
,
loaded_boundary
=
0
,
0
if
envs
.
VLLM_USE_NN
:
loaded_total_dim
=
sum
(
full_dim
-
extra
for
full_dim
,
extra
,
_
in
shard_spec
)
param_out_axis
=
0
if
param
.
dim
()
==
1
else
(
param
.
dim
()
-
1
)
loaded_out_axis
=
0
if
(
loaded_weight
.
dim
()
>
1
and
loaded_weight
.
shape
[
-
1
]
==
loaded_total_dim
and
loaded_weight
.
shape
[
0
]
!=
loaded_total_dim
):
loaded_out_axis
=
loaded_weight
.
dim
()
-
1
# - iterate over the shard specs
for
full_dim
,
extra
,
duplicate_groups
in
shard_spec
:
# - full dim is the model dim (before TP).
...
...
@@ -201,6 +211,32 @@ def mamba_v2_sharded_weight_loader(
# - the ignore is for a mundane mypy error as it does not
# seem to handle slices well.
# https://github.com/python/mypy/issues/2410
if
envs
.
VLLM_USE_NN
:
if
take
>
0
:
param_slice
=
param
.
data
.
narrow
(
param_out_axis
,
boundary
,
take
)
loaded_slice
=
loaded_weight
.
narrow
(
loaded_out_axis
,
loaded_start_idx
,
take
)
if
(
param_slice
.
dim
()
==
loaded_slice
.
dim
()
+
1
and
param_slice
.
shape
[
1
]
==
1
):
loaded_slice
=
loaded_slice
.
unsqueeze
(
1
)
elif
(
loaded_slice
.
dim
()
==
param_slice
.
dim
()
+
1
and
loaded_slice
.
shape
[
1
]
==
1
):
loaded_slice
=
loaded_slice
.
squeeze
(
1
)
if
param_slice
.
shape
!=
loaded_slice
.
shape
:
loaded_slice
=
loaded_slice
.
permute
(
*
reversed
(
range
(
loaded_slice
.
dim
())))
if
param_slice
.
shape
!=
loaded_slice
.
shape
:
raise
RuntimeError
(
"mamba_v2_sharded_weight_loader shape mismatch: "
f
"param_slice=
{
tuple
(
param_slice
.
shape
)
}
"
f
"loaded_slice=
{
tuple
(
loaded_slice
.
shape
)
}
"
f
"(param_out_axis=
{
param_out_axis
}
, "
f
"loaded_out_axis=
{
loaded_out_axis
}
)"
)
param_slice
.
copy_
(
loaded_slice
)
else
:
param
.
data
[
boundary
:
(
boundary
+
take
),
...
# type: ignore[misc]
]
=
loaded_weight
[
...
...
@@ -428,6 +464,10 @@ class MambaMixer2(MambaBase, CustomOp):
# `ColumnParallelLinear` and `MergedColumnParallelLinear`,
# and `set_weight_attrs` doesn't allow to override it
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
if
envs
.
VLLM_USE_NN
:
conv_weights
=
self
.
conv1d
.
weight
.
squeeze
(
1
).
transpose
(
0
,
1
).
contiguous
()
else
:
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
)
)
...
...
vllm/model_executor/layers/mamba/mamba_utils.py
View file @
425eb81e
...
...
@@ -80,9 +80,13 @@ class MambaStateDtypeCalculator:
cls
,
model_dtype
:
ModelDType
|
torch
.
dtype
,
mamba_cache_dtype
:
MambaDType
,
mamba_ssm_cache_dtype
:
MambaDType
=
"auto"
,
)
->
tuple
[
torch
.
dtype
,
torch
.
dtype
]:
state_dtype
=
get_kv_cache_torch_dtype
(
mamba_cache_dtype
,
model_dtype
)
return
(
state_dtype
,
state_dtype
)
# state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
# return (state_dtype, state_dtype)
return
cls
.
_mamba_state_dtype
(
model_dtype
,
mamba_cache_dtype
,
mamba_ssm_cache_dtype
)
@
classmethod
def
kda_state_dtype
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
425eb81e
...
...
@@ -1109,22 +1109,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
not
self
.
is_monolithic
assert
self
.
kernel
is
not
None
return
self
.
kernel
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
inplace
=
self
.
use_inplace
,
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
# TODO(rob): investigate the disable_expert_map introduced by:
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
quant_config
=
self
.
moe_quant_config
,
use_fused_gate
=
use_fused_gate
,
use_nn_moe
=
False
,
)
@
property
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
425eb81e
...
...
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
)
try
:
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
from
lmslim.layers.fused_moe.fuse_moe_fp8_marlin
import
fused_experts_impl_fp8_marlin
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
...
...
@@ -35,9 +36,32 @@ logger = init_logger(__name__)
__all__
=
[
"CompressedTensorsW8A8Int8MarlinMoEMethod"
,
"CompressedTensorsW8A8FP8MarlinMoEMethod"
,
]
def
fp32_to_fp8_e4m3fn
(
t
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""更合理的FP32到Float8_e4m3fn转换,使用最近值而不是简单舍弃尾数"""
# torch.float8_e4m3fn的数值范围约[-448, 448]
fp8_min
,
fp8_max
=
-
448.0
,
448.0
t_clamped
=
t
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
)
# 保证不会下溢到0
# 转换前到float16再转fp8可能提升精度(float8实现本身通常通过float16做rounding)
t_fp16
=
t_clamped
.
to
(
torch
.
float16
)
return
t_fp16
.
to
(
torch
.
float8_e4m3fn
)
def
w8a8_fp8_nt_kpack2_marlin_weight
(
w8a8_w
,
# [size_n, size_k// 2 ]
k_tile
=
16
,
n_tile
=
16
,
):
size_n
,
size_k
=
w8a8_w
.
shape
assert
size_n
%
k_tile
==
0
and
size_k
%
n_tile
==
0
,
"k_tile / n_tile 必须能整除对应维度"
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
n_tile
,
n_tile
,
size_k
//
k_tile
,
k_tile
))
w8a8_w
=
w8a8_w
.
permute
((
0
,
2
,
1
,
3
)).
contiguous
()
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
k_tile
,
size_k
*
k_tile
))
return
w8a8_w
class
CompressedTensorsMarlinMoEMethod
(
FusedMoEMethodBase
):
def
__init_
(
self
,
moe
:
FusedMoEConfig
):
super
().
__init__
(
moe
)
...
...
@@ -52,12 +76,191 @@ class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
weight_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
input_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
if
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
if
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8FP8MarlinMoEMethod
(
quant_config
,
layer
.
moe_config
)
elif
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8MarlinMoEMethod
(
quant_config
,
layer
.
moe_config
)
else
:
raise
RuntimeError
(
f
"Slimquant_marlin does not support the FusedMoe scheme:
{
weight_quant
}
,
{
input_quant
}
"
)
class
CompressedTensorsW8A8FP8MarlinMoEMethod
(
CompressedTensorsMarlinMoEMethod
):
def
__init__
(
self
,
quant_config
:
"CompressedTensorsMarlinConfig"
,
# type: ignore # noqa E501
moe
:
FusedMoEConfig
):
self
.
quant_config
=
quant_config
super
().
__init__
(
moe
)
self
.
weight_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
self
.
input_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
per_channel
=
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
)
if
not
per_channel
:
raise
ValueError
(
"For FP8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f
"
{
self
.
weight_quant
}
,
{
self
.
input_quant
}
"
)
self
.
static_input_scales
=
not
self
.
input_quant
.
dynamic
if
self
.
static_input_scales
:
raise
ValueError
(
"For FP8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales."
)
self
.
fused_experts
=
self
.
fused_moe_forward
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
FusedMoEQuantConfig
]:
return
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
=
torch
.
float8_e4m3fn
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
assert
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size_per_partition
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
assert
not
self
.
static_input_scales
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
.
float
()
if
w1_marlin_in
.
dtype
==
torch
.
float8_e4m3fn
else
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
w1_marlin
=
fp32_to_fp8_e4m3fn
(
w1_marlin
)
del
w1_marlin_list
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
.
float
()
if
w2_marlin_in
.
dtype
==
torch
.
float8_e4m3fn
else
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
w2_marlin
=
fp32_to_fp8_e4m3fn
(
w2_marlin
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
def
fused_moe_forward
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
):
return
fused_experts_impl_fp8_marlin
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
True
,
per_channel_quant
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet."
)
return
self
.
fused_experts
(
layer
=
layer
,
x
=
x
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
activation
=
activation
,
routed_scaling_factor
=
routed_scaling_factor
,
shared_output
=
shared_output
,
)
class
CompressedTensorsW8A8Int8MarlinMoEMethod
(
CompressedTensorsMarlinMoEMethod
):
def
__init__
(
self
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
425eb81e
...
...
@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
typing
import
Optional
from
vllm
import
envs
import
torch
from
compressed_tensors.quantization
import
QuantizationArgs
,
QuantizationStrategy
from
torch.nn
import
Parameter
...
...
@@ -40,7 +41,6 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter
,
PerTensorScaleParameter
,
)
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
strategy_to_parameter_type
=
{
...
...
@@ -159,8 +159,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_channel_strategy
(
layer
.
weight
,
layer
.
weight_scale
,
getattr
(
layer
,
"input_scale"
,
None
)
)
weight
=
weight
.
t
()
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
weight
=
weight
.
t
().
contiguous
()
# triton不用转置,torch需要
# else:
# weight = weight.t()
elif
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
assert
self
.
is_static_input_scheme
is
False
weight
,
weight_scale
=
process_fp8_weight_block_strategy
(
...
...
@@ -193,6 +196,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
,
)
->
torch
.
Tensor
:
if
self
.
weight_block_size
is
not
None
:
return
self
.
w8a8_block_fp8_linear
.
apply
(
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
425eb81e
...
...
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional
import
torch
from
torch.nn
import
Module
from
torch.utils._python_dispatch
import
TorchDispatchMode
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
...
...
@@ -1027,20 +1026,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
kernel
is
not
None
assert
not
self
.
is_monolithic
return
self
.
kernel
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
inplace
=
self
.
use_inplace
,
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
quant_config
=
self
.
moe_quant_config
,
use_fused_gate
=
use_fused_gate
,
use_nn_moe
=
False
,
)
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
View file @
425eb81e
...
...
@@ -12,7 +12,12 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
)
try
:
from
lmslim.quantize.quant_ops
import
hipblaslt_w8a8_channelwise_gemm
from
lmslim.layers.gemm.fp8_utils
import
triton_scaled_mm_fp8
except
ImportError
:
print
(
"INFO: Please updata lmslim if you want to use fp8_utils.
\n
"
)
from
vllm
import
envs
class
TorchFP8ScaledMMLinearKernel
(
FP8ScaledMMLinearKernel
):
"""
...
...
@@ -176,46 +181,31 @@ class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
bias
:
torch
.
Tensor
|
None
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
# Use unfused DQ due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as scales
dummy_tensor
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
,
device
=
A
.
device
)
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output
=
torch
.
_scaled_mm
(
m
=
A
.
shape
[
0
]
k
=
A
.
shape
[
1
]
n
=
B
.
shape
[
0
]
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
_
,
output
=
hipblaslt_w8a8_channelwise_gemm
(
a
=
A
,
b
=
B
,
scale_a
=
As
,
scale_b
=
Bs
,
m
=
m
,
n
=
n
,
k
=
k
,
transpose_flag
=
"NT"
,
out_dtype
=
out_dtype
,
bias
=
bias
,
)
return
output
.
view
(
m
,
n
)
else
:
output
=
triton_scaled_mm_fp8
(
A
,
B
,
scale_a
=
dummy_tensor
,
scale_b
=
dummy_tensor
,
out_dtype
=
torch
.
float32
,
scale_a
=
As
,
scale_b
=
Bs
,
out_dtype
=
out_dtype
,
bias
=
bias
,
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
# Unpad (undo num_token_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
output_shape
[
0
])
x_scale
=
torch
.
narrow
(
As
,
0
,
0
,
output_shape
[
0
])
# DQ
# C = sw * sx * (X * W) + bias
output
=
output
*
x_scale
*
Bs
.
t
()
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
out_dtype
).
view
(
*
output_shape
)
return
output
.
view
(
*
output_shape
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment