Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
42135d68
"vscode:/vscode.git/clone" did not exist on "aeb37c2a725554791ff6f258b1e18830867a3ab9"
Unverified
Commit
42135d68
authored
Jan 21, 2026
by
Robert Shaw
Committed by
GitHub
Jan 21, 2026
Browse files
[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)
parent
e14467be
Changes
82
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
267 additions
and
177 deletions
+267
-177
tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml
...e-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml
+2
-0
tests/kernels/moe/modular_kernel_tools/common.py
tests/kernels/moe/modular_kernel_tools/common.py
+5
-0
tests/kernels/moe/modular_kernel_tools/mk_objects.py
tests/kernels/moe/modular_kernel_tools/mk_objects.py
+17
-75
tests/kernels/moe/test_batched_deepgemm.py
tests/kernels/moe/test_batched_deepgemm.py
+3
-0
tests/kernels/moe/test_block_fp8.py
tests/kernels/moe/test_block_fp8.py
+39
-3
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+16
-13
tests/kernels/moe/test_deepep_deepgemm_moe.py
tests/kernels/moe/test_deepep_deepgemm_moe.py
+6
-5
tests/kernels/moe/test_deepep_moe.py
tests/kernels/moe/test_deepep_moe.py
+11
-2
tests/kernels/moe/test_deepgemm.py
tests/kernels/moe/test_deepgemm.py
+26
-13
tests/kernels/moe/test_flashinfer.py
tests/kernels/moe/test_flashinfer.py
+25
-20
tests/kernels/moe/test_flashinfer_moe.py
tests/kernels/moe/test_flashinfer_moe.py
+29
-5
tests/kernels/moe/test_modular_oai_triton_moe.py
tests/kernels/moe/test_modular_oai_triton_moe.py
+4
-2
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+3
-3
tests/kernels/moe/test_nvfp4_moe.py
tests/kernels/moe/test_nvfp4_moe.py
+2
-3
tests/kernels/moe/test_pplx_cutlass_moe.py
tests/kernels/moe/test_pplx_cutlass_moe.py
+26
-21
tests/kernels/moe/test_pplx_moe.py
tests/kernels/moe/test_pplx_moe.py
+2
-0
tests/kernels/moe/test_routing.py
tests/kernels/moe/test_routing.py
+0
-7
tests/kernels/moe/test_triton_moe_no_act_mul.py
tests/kernels/moe/test_triton_moe_no_act_mul.py
+13
-3
tests/kernels/moe/utils.py
tests/kernels/moe/utils.py
+36
-1
tools/vllm-rocm/pin_rocm_dependencies.py
tools/vllm-rocm/pin_rocm_dependencies.py
+2
-1
No files found.
tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml
View file @
42135d68
...
@@ -3,3 +3,5 @@ accuracy_threshold: 0.88
...
@@ -3,3 +3,5 @@ accuracy_threshold: 0.88
num_questions
:
1319
num_questions
:
1319
num_fewshot
:
5
num_fewshot
:
5
server_args
:
"
--enforce-eager
--max-model-len
8192
--tensor-parallel-size
2"
server_args
:
"
--enforce-eager
--max-model-len
8192
--tensor-parallel-size
2"
env
:
VLLM_USE_FLASHINFER_MOE_FP4
:
"
0"
tests/kernels/moe/modular_kernel_tools/common.py
View file @
42135d68
...
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
)
)
from
vllm.utils.import_utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils.import_utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
...
@@ -574,10 +575,14 @@ def make_modular_kernel(
...
@@ -574,10 +575,14 @@ def make_modular_kernel(
num_experts
=
config
.
E
,
num_experts
=
config
.
E
,
experts_per_token
=
config
.
topk
,
experts_per_token
=
config
.
topk
,
hidden_dim
=
config
.
K
,
hidden_dim
=
config
.
K
,
intermediate_size_per_partition
=
config
.
N
,
num_local_experts
=
config
.
num_local_experts
,
num_local_experts
=
config
.
num_local_experts
,
moe_parallel_config
=
moe_parallel_config
,
moe_parallel_config
=
moe_parallel_config
,
in_dtype
=
config
.
dtype
,
in_dtype
=
config
.
dtype
,
max_num_tokens
=
next_power_of_2
(
config
.
M
),
max_num_tokens
=
next_power_of_2
(
config
.
M
),
activation
=
"silu"
,
device
=
vllm_config
.
device_config
.
device
,
routing_method
=
RoutingMethodType
.
DeepSeekV3
,
)
)
# make modular kernel
# make modular kernel
...
...
tests/kernels/moe/modular_kernel_tools/mk_objects.py
View file @
42135d68
...
@@ -425,84 +425,26 @@ def make_fused_experts(
...
@@ -425,84 +425,26 @@ def make_fused_experts(
num_dispatchers
:
int
,
num_dispatchers
:
int
,
N
:
int
,
N
:
int
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
batch_kwargs
=
{
if
(
"max_num_tokens"
:
moe
.
max_num_tokens
,
fused_experts_type
.
activation_format
()
"num_dispatchers"
:
num_dispatchers
,
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
}
):
quant_kwargs
=
{
"quant_config"
:
quant_config
,
}
deepgemm_kwargs
=
{
"allow_deep_gemm"
:
has_deep_gemm
()}
torch
.
set_printoptions
(
threshold
=
0
,
edgeitems
=
0
,
linewidth
=
10000
)
if
fused_experts_type
==
BatchedDeepGemmExperts
:
kwargs
=
batch_kwargs
|
quant_kwargs
print
(
f
"Making BatchedDeepGemmExperts
{
kwargs
}
..."
)
experts
=
BatchedDeepGemmExperts
(
**
kwargs
)
elif
fused_experts_type
==
BatchedTritonExperts
:
kwargs
=
batch_kwargs
|
quant_kwargs
print
(
f
"Making BatchedTritonExperts
{
kwargs
}
..."
)
experts
=
BatchedTritonExperts
(
**
kwargs
)
elif
fused_experts_type
==
DeepGemmExperts
:
print
(
f
"Making DeepGemmExperts
{
quant_config
}
..."
)
experts
=
DeepGemmExperts
(
quant_config
)
elif
fused_experts_type
==
TritonExperts
:
kwargs
=
quant_kwargs
print
(
f
"Making TritonExperts
{
kwargs
}
..."
)
experts
=
TritonExperts
(
**
kwargs
)
elif
fused_experts_type
==
TritonOrDeepGemmExperts
:
kwargs
=
quant_kwargs
|
deepgemm_kwargs
print
(
f
"Making TritonOrDeepGemmExperts
{
kwargs
}
..."
)
experts
=
TritonOrDeepGemmExperts
(
**
kwargs
)
elif
fused_experts_type
==
NaiveBatchedExperts
:
kwargs
=
batch_kwargs
|
quant_kwargs
print
(
f
"Making NaiveBatchedExperts
{
kwargs
}
..."
)
experts
=
NaiveBatchedExperts
(
**
kwargs
)
elif
fused_experts_type
==
CutlassExpertsFp8
:
strides
=
make_cutlass_strides
(
moe
.
num_experts
,
N
,
moe
.
hidden_dim
)
kwargs
=
{
"out_dtype"
:
moe
.
in_dtype
,
"ab_strides1"
:
strides
[
0
],
"ab_strides2"
:
strides
[
1
],
"c_strides1"
:
strides
[
2
],
"c_strides2"
:
strides
[
3
],
}
|
quant_kwargs
print
(
f
"Making CutlassExpertsFp8
{
kwargs
}
..."
)
experts
=
CutlassExpertsFp8
(
**
kwargs
)
elif
fused_experts_type
==
CutlassBatchedExpertsFp8
:
strides
=
make_cutlass_strides
(
moe
.
num_experts
,
N
,
moe
.
hidden_dim
)
kwargs
=
{
"max_experts_per_worker"
:
moe
.
num_local_experts
,
"num_dispatchers"
:
num_dispatchers
,
"out_dtype"
:
moe
.
in_dtype
,
"ab_strides1"
:
strides
[
0
],
"ab_strides2"
:
strides
[
1
],
"c_strides1"
:
strides
[
2
],
"c_strides2"
:
strides
[
3
],
}
|
quant_kwargs
print
(
f
"Making CutlassBatchedExpertsFp8
{
kwargs
}
..."
)
experts
=
CutlassBatchedExpertsFp8
(
**
kwargs
)
elif
fused_experts_type
==
CutlassExpertsFp4
:
kwargs
=
{
kwargs
=
{
"max_experts_per_worker"
:
moe
.
num_local_experts
,
"moe_config"
:
moe
,
"quant_config"
:
quant_config
,
"max_num_tokens"
:
moe
.
max_num_tokens
,
"num_dispatchers"
:
num_dispatchers
,
"num_dispatchers"
:
num_dispatchers
,
"out_dtype"
:
moe
.
in_dtype
,
}
}
|
quant_kwargs
print
(
f
"Making CutlassExpertsFp4
{
kwargs
}
..."
)
experts
=
CutlassExpertsFp4
(
**
kwargs
)
elif
fused_experts_type
==
FlashInferExperts
:
kwargs
=
{
"out_dtype"
:
moe
.
in_dtype
,
"ep_rank"
:
moe
.
ep_rank
,
"ep_size"
:
moe
.
ep_size
,
"tp_rank"
:
moe
.
tp_rank
,
"tp_size"
:
moe
.
tp_size
,
}
|
quant_kwargs
print
(
f
"Making FlashInferExperts
{
kwargs
}
..."
)
experts
=
FlashInferExperts
(
**
kwargs
)
else
:
else
:
raise
RuntimeError
(
f
"Unknown fused experts type:
{
fused_experts_type
}
"
)
kwargs
=
{
"moe_config"
:
moe
,
"quant_config"
:
quant_config
,
}
torch
.
set_printoptions
(
threshold
=
0
,
edgeitems
=
0
,
linewidth
=
10000
)
print
(
f
"Making
{
fused_experts_type
.
__class__
.
__name__
}
{
kwargs
}
..."
)
experts
=
fused_experts_type
(
**
kwargs
)
torch
.
set_printoptions
(
threshold
=
1000
,
edgeitems
=
5
,
linewidth
=
80
)
torch
.
set_printoptions
(
threshold
=
1000
,
edgeitems
=
5
,
linewidth
=
80
)
...
...
tests/kernels/moe/test_batched_deepgemm.py
View file @
42135d68
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularK
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularK
from
vllm.utils.deep_gemm
import
calc_diff
,
is_deep_gemm_supported
from
vllm.utils.deep_gemm
import
calc_diff
,
is_deep_gemm_supported
from
.test_deepgemm
import
make_block_quant_fp8_weights
from
.test_deepgemm
import
make_block_quant_fp8_weights
from
.utils
import
make_dummy_moe_config
BLOCK_SIZE
=
[
128
,
128
]
BLOCK_SIZE
=
[
128
,
128
]
...
@@ -71,6 +72,7 @@ def test_batched_deepgemm_vs_triton(
...
@@ -71,6 +72,7 @@ def test_batched_deepgemm_vs_triton(
max_num_tokens
=
max_num_tokens
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
1
,
num_dispatchers
=
1
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
moe_config
=
make_dummy_moe_config
(),
)
)
mk_triton
=
FusedMoEModularKernel
(
prep_finalize
,
triton_experts
)
mk_triton
=
FusedMoEModularKernel
(
prep_finalize
,
triton_experts
)
...
@@ -89,6 +91,7 @@ def test_batched_deepgemm_vs_triton(
...
@@ -89,6 +91,7 @@ def test_batched_deepgemm_vs_triton(
max_num_tokens
=
max_num_tokens
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
1
,
num_dispatchers
=
1
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
moe_config
=
make_dummy_moe_config
(),
)
)
mk_deepgemm
=
FusedMoEModularKernel
(
prep_finalize
,
deepgemm_experts
)
mk_deepgemm
=
FusedMoEModularKernel
(
prep_finalize
,
deepgemm_experts
)
...
...
tests/kernels/moe/test_block_fp8.py
View file @
42135d68
...
@@ -4,7 +4,12 @@
...
@@ -4,7 +4,12 @@
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.moe.utils
import
make_test_quant_config
,
make_test_weights
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
tests.kernels.moe.utils
import
(
make_dummy_moe_config
,
make_test_quant_config
,
make_test_weights
,
)
from
tests.kernels.quant_utils
import
(
from
tests.kernels.quant_utils
import
(
native_per_token_group_quant_fp8
,
native_per_token_group_quant_fp8
,
native_w8a8_block_matmul
,
native_w8a8_block_matmul
,
...
@@ -15,13 +20,21 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -15,13 +20,21 @@ from vllm.model_executor.layers.fused_moe import (
fused_experts
,
fused_experts
,
fused_topk
,
fused_topk
,
)
)
from
vllm.model_executor.layers.fused_moe.config
import
(
fp8_w8a8_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_valid_deep_gemm_shape
,
_valid_deep_gemm_shape
,
deep_gemm_moe_fp8
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
modular_triton_fused_moe
,
modular_triton_fused_moe
,
)
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
(
from
vllm.utils.deep_gemm
import
(
get_mk_alignment_for_contiguous_layout
,
get_mk_alignment_for_contiguous_layout
,
...
@@ -161,7 +174,7 @@ def test_w8a8_block_fp8_fused_moe(
...
@@ -161,7 +174,7 @@ def test_w8a8_block_fp8_fused_moe(
block_shape
=
block_size
,
block_shape
=
block_size
,
)
)
m_fused_moe
=
modular_triton_fused_moe
(
quant_config
)
m_fused_moe
=
modular_triton_fused_moe
(
make_dummy_moe_config
(),
quant_config
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
...
@@ -236,6 +249,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
...
@@ -236,6 +249,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
quant_config
=
fp8_w8a8_moe_quant_config
(
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
block_shape
=
block_size
,
)
deep_gemm_experts
=
mk
.
FusedMoEModularKernel
(
prepare_finalize
=
MoEPrepareAndFinalizeNoEP
(),
fused_experts
=
TritonOrDeepGemmExperts
(
moe_config
=
make_dummy_moe_config
(),
quant_config
=
quant_config
,
),
)
def
deep_gemm_moe_fp8
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
):
return
deep_gemm_experts
(
hidden_states
=
a
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
)
# Set the context to avoid lots of warning spam.
# Set the context to avoid lots of warning spam.
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
ref_out
=
torch_w8a8_block_fp8_moe
(
ref_out
=
torch_w8a8_block_fp8_moe
(
...
...
tests/kernels/moe/test_cutlass_moe.py
View file @
42135d68
...
@@ -8,6 +8,7 @@ import pytest
...
@@ -8,6 +8,7 @@ import pytest
import
torch
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
tests.kernels.moe.utils
import
make_dummy_moe_config
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe
import
fused_experts
,
fused_topk
...
@@ -193,16 +194,18 @@ def run_with_expert_maps(
...
@@ -193,16 +194,18 @@ def run_with_expert_maps(
out_tensor
=
torch
.
zeros_like
(
cutlass_moe_kwargs
[
"hidden_states"
])
out_tensor
=
torch
.
zeros_like
(
cutlass_moe_kwargs
[
"hidden_states"
])
for
kwargs
,
new_quant_config
in
slice_experts
():
for
kwargs
,
new_quant_config
in
slice_experts
():
w2
=
kwargs
[
"w2"
]
a
=
kwargs
[
"hidden_states"
]
kernel
=
mk
.
FusedMoEModularKernel
(
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
MoEPrepareAndFinalizeNoEP
(),
CutlassExpertsFp8
(
CutlassExpertsFp8
(
out_dtype
=
kwargs
[
"hidden_states"
].
dtype
,
moe_config
=
make_dummy_moe_config
(
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
num_experts
=
w2
.
shape
[
0
],
e
=
kwargs
[
"w2"
].
shape
[
0
],
# type: ignore[union-attr]
hidden_dim
=
w2
.
shape
[
1
],
n
=
kwargs
[
"w2"
].
shape
[
2
],
# type: ignore[union-attr]
intermediate_size_per_partition
=
w2
.
shape
[
2
],
k
=
kwargs
[
"w2"
].
shape
[
1
],
# type: ignore[union-attr]
in_dtype
=
a
.
dtype
,
),
quant_config
=
new_quant_config
,
quant_config
=
new_quant_config
,
device
=
"cuda"
,
),
),
)
)
out_tensor
=
out_tensor
+
kernel
(
**
kwargs
)
out_tensor
=
out_tensor
+
kernel
(
**
kwargs
)
...
@@ -249,19 +252,19 @@ def run_8_bit(
...
@@ -249,19 +252,19 @@ def run_8_bit(
"topk_ids"
:
topk_ids
,
"topk_ids"
:
topk_ids
,
}
}
num_experts
=
moe_tensors
.
w1
.
size
(
0
)
num_experts
=
moe_tensors
.
w1
.
size
(
0
)
# type: ignore[attr-defined]
with_ep
=
num_local_experts
is
not
None
or
num_local_experts
==
num_experts
with_ep
=
num_local_experts
is
not
None
or
num_local_experts
==
num_experts
if
not
with_ep
:
if
not
with_ep
:
kernel
=
mk
.
FusedMoEModularKernel
(
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
MoEPrepareAndFinalizeNoEP
(),
CutlassExpertsFp8
(
CutlassExpertsFp8
(
out_dtype
=
moe_tensors
.
a
.
dtype
,
moe_config
=
make_dummy_moe_config
(
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
num_experts
=
moe_tensors
.
w2_q
.
shape
[
0
],
# type: ignore[union-attr]
e
=
moe_tensors
.
w2_q
.
shape
[
0
],
# type: ignore[union-attr]
hidden_dim
=
moe_tensors
.
w2_q
.
shape
[
1
],
# type: ignore[union-attr]
n
=
moe_tensors
.
w2_q
.
shape
[
2
],
# type: ignore[union-attr]
intermediate_size_per_partition
=
moe_tensors
.
w2_q
.
shape
[
2
],
# type: ignore[union-attr]
k
=
moe_tensors
.
w2_q
.
shape
[
1
],
# type: ignore[union-attr]
in_dtype
=
moe_tensors
.
a
.
dtype
,
),
quant_config
=
quant_config
,
quant_config
=
quant_config
,
device
=
"cuda"
,
),
),
)
)
return
kernel
(
**
kwargs
)
return
kernel
(
**
kwargs
)
...
...
tests/kernels/moe/test_deepep_deepgemm_moe.py
View file @
42135d68
...
@@ -33,7 +33,7 @@ from vllm.v1.worker.workspace import init_workspace_manager
...
@@ -33,7 +33,7 @@ from vllm.v1.worker.workspace import init_workspace_manager
from
...utils
import
multi_gpu_test
from
...utils
import
multi_gpu_test
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.utils
import
make_test_weights
from
.utils
import
make_dummy_moe_config
,
make_test_weights
if
has_deep_ep
():
if
has_deep_ep
():
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
...
@@ -192,6 +192,7 @@ def make_ll_modular_kernel(
...
@@ -192,6 +192,7 @@ def make_ll_modular_kernel(
max_num_tokens
=
max_tokens_per_rank
,
max_num_tokens
=
max_tokens_per_rank
,
num_dispatchers
=
pgi
.
world_size
//
dp_size
,
num_dispatchers
=
pgi
.
world_size
//
dp_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
moe_config
=
make_dummy_moe_config
(),
)
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
fused_experts
=
fused_experts
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
fused_experts
=
fused_experts
)
return
mk
return
mk
...
@@ -219,7 +220,10 @@ def make_ht_modular_kernel(
...
@@ -219,7 +220,10 @@ def make_ht_modular_kernel(
block_shape
=
test_config
.
block_size
,
block_shape
=
test_config
.
block_size
,
)
)
fused_experts
=
DeepGemmExperts
(
quant_config
)
fused_experts
=
DeepGemmExperts
(
moe_config
=
make_dummy_moe_config
(),
quant_config
=
quant_config
,
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
fused_experts
=
fused_experts
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
fused_experts
=
fused_experts
)
return
mk
return
mk
...
@@ -349,9 +353,6 @@ def triton_impl(
...
@@ -349,9 +353,6 @@ def triton_impl(
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
False
,
inplace
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
# Make sure this is set to False so we
# don't end up comparing the same implementation.
allow_deep_gemm
=
False
,
)
)
...
...
tests/kernels/moe/test_deepep_moe.py
View file @
42135d68
...
@@ -10,11 +10,14 @@ import pytest
...
@@ -10,11 +10,14 @@ import pytest
import
torch.distributed
import
torch.distributed
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
tests.kernels.moe.utils
import
make_dummy_moe_config
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
TritonExperts
from
vllm.model_executor.layers.fused_moe
import
TritonExperts
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
BatchedTritonExperts
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
BatchedTritonExperts
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
...
@@ -160,15 +163,21 @@ def make_modular_kernel(
...
@@ -160,15 +163,21 @@ def make_modular_kernel(
num_dispatchers
=
pgi
.
world_size
//
dp_size
num_dispatchers
=
pgi
.
world_size
//
dp_size
moe_config
=
make_dummy_moe_config
()
if
low_latency_mode
:
if
low_latency_mode
:
assert
not
quant_config
.
per_act_token_quant
,
"not supported in ll mode"
assert
not
quant_config
.
per_act_token_quant
,
"not supported in ll mode"
fused_experts
=
BatchedTritonExperts
(
fused_experts
=
BatchedTritonExperts
(
max_num_tokens
=
MAX_TOKENS_PER_RANK
,
max_num_tokens
=
MAX_TOKENS_PER_RANK
,
num_dispatchers
=
num_dispatchers
,
num_dispatchers
=
num_dispatchers
,
moe_config
=
moe_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
else
:
else
:
fused_experts
=
TritonExperts
(
quant_config
=
quant_config
)
fused_experts
=
TritonExperts
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
fused_experts
=
fused_experts
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
fused_experts
=
fused_experts
)
return
mk
return
mk
...
...
tests/kernels/moe/test_deepgemm.py
View file @
42135d68
...
@@ -11,10 +11,19 @@ import math
...
@@ -11,10 +11,19 @@ import math
import
pytest
import
pytest
import
torch
import
torch
from
vllm.model_executor.layers.fused_moe.config
import
fp8_w8a8_moe_quant_config
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
tests.kernels.moe.utils
import
make_dummy_moe_config
from
vllm.model_executor.layers.fused_moe.config
import
(
fp8_w8a8_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
,
per_token_group_quant_fp8
,
)
)
...
@@ -100,6 +109,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
...
@@ -100,6 +109,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
block_shape
=
block_size
,
block_shape
=
block_size
,
)
)
deep_gemm_experts
=
mk
.
FusedMoEModularKernel
(
prepare_finalize
=
MoEPrepareAndFinalizeNoEP
(),
fused_experts
=
TritonOrDeepGemmExperts
(
moe_config
=
make_dummy_moe_config
(),
quant_config
=
quant_config
,
),
)
# triton reference
# triton reference
out_triton
=
fused_experts
(
out_triton
=
fused_experts
(
hidden_states
=
tokens_bf16
,
hidden_states
=
tokens_bf16
,
...
@@ -109,19 +126,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
...
@@ -109,19 +126,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
False
,
inplace
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
allow_deep_gemm
=
False
,
)
)
# DeepGemm
# DeepGemm
out_deepgemm
=
fused
_experts
(
out_deepgemm
=
deep_gemm
_experts
(
hidden_states
=
tokens_bf16
,
hidden_states
=
tokens_bf16
,
w1
=
w1
,
w1
=
w1
,
w2
=
w2
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
False
,
inplace
=
False
,
quant_config
=
quant_config
,
allow_deep_gemm
=
True
,
)
)
diff
=
calc_diff
(
out_deepgemm
,
out_triton
)
diff
=
calc_diff
(
out_deepgemm
,
out_triton
)
assert
diff
<
0.001
,
f
"Diff exceeded 1%:
{
diff
}
"
assert
diff
<
0.001
,
f
"Diff exceeded 1%:
{
diff
}
"
...
@@ -147,20 +161,19 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i
...
@@ -147,20 +161,19 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i
with
monkeypatch
.
context
()
as
mp
:
with
monkeypatch
.
context
()
as
mp
:
mp
.
setenv
(
"VLLM_USE_DEEP_GEMM"
,
"1"
)
mp
.
setenv
(
"VLLM_USE_DEEP_GEMM"
,
"1"
)
_
fused_moe_mod
=
importlib
.
import_module
(
_
DeepGemmExperts
=
importlib
.
import_module
(
"vllm.model_executor.layers.fused_moe.
fused
_moe"
"vllm.model_executor.layers.fused_moe.
deep_gemm
_moe"
)
)
.
DeepGemmExperts
call_counter
=
{
"cnt"
:
0
}
call_counter
=
{
"cnt"
:
0
}
orig_fn
=
_
fused_moe_mod
.
d
eep
_g
emm
_moe_fp8
orig_fn
=
_
D
eep
G
emm
Experts
.
apply
def
_spy_
deep_gemm_moe_fp8
(
*
args
,
**
kwargs
):
def
_spy_
apply
(
*
args
,
**
kwargs
):
call_counter
[
"cnt"
]
+=
1
call_counter
[
"cnt"
]
+=
1
return
orig_fn
(
*
args
,
**
kwargs
)
return
orig_fn
(
*
args
,
**
kwargs
)
monkeypatch
.
setattr
(
_fused_moe_mod
,
"deep_gemm_moe_fp8"
,
_spy_deep_gemm_moe_fp8
)
monkeypatch
.
setattr
(
_DeepGemmExperts
,
"apply"
,
_spy_apply
)
if
topk
>
num_experts
:
if
topk
>
num_experts
:
pytest
.
skip
(
f
"topk=
{
topk
}
> num_experts=
{
num_experts
}
"
)
pytest
.
skip
(
f
"topk=
{
topk
}
> num_experts=
{
num_experts
}
"
)
...
...
tests/kernels/moe/test_flashinfer.py
View file @
42135d68
...
@@ -8,7 +8,10 @@ import torch
...
@@ -8,7 +8,10 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
fp8_w8a8_moe_quant_config
,
fp8_w8a8_moe_quant_config
,
)
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
...
@@ -116,18 +119,7 @@ class TestData:
...
@@ -116,18 +119,7 @@ class TestData:
layer
.
w13_weight_scale
=
w13_weight_scale
layer
.
w13_weight_scale
=
w13_weight_scale
layer
.
w2_weight_scale
=
w2_weight_scale
layer
.
w2_weight_scale
=
w2_weight_scale
# Setup dummy config.
# Setup dummy config.
layer
.
moe_parallel_config
=
mk
.
FusedMoEParallelConfig
(
layer
.
moe_parallel_config
=
mk
.
FusedMoEParallelConfig
.
make_no_parallel
()
tp_size
=
1
,
pcp_size
=
1
,
dp_size
=
1
,
ep_size
=
1
,
tp_rank
=
0
,
pcp_rank
=
0
,
dp_rank
=
0
,
ep_rank
=
0
,
use_ep
=
False
,
all2all_backend
=
"naive"
,
)
# flashinfer expects swapped rows for w13
# flashinfer expects swapped rows for w13
layer
.
w13_weight
.
data
=
swap_w13_to_w31
(
layer
.
w13_weight
.
data
)
layer
.
w13_weight
.
data
=
swap_w13_to_w31
(
layer
.
w13_weight
.
data
)
...
@@ -238,6 +230,8 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
...
@@ -238,6 +230,8 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
):
):
set_random_seed
(
7
)
set_random_seed
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
assert
activation
in
[
"silu"
,
"relu2_no_mul"
]
is_act_and_mul
=
activation
==
"silu_and_mul"
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
td
=
TestData
.
make_moe_tensors_8bit
(
td
=
TestData
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
is_trtllm
=
False
,
activation
=
activation
m
,
k
,
n
,
e
,
is_trtllm
=
False
,
activation
=
activation
...
@@ -285,19 +279,30 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
...
@@ -285,19 +279,30 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td
.
layer
.
get_fused_moe_quant_config
=
get_fused_moe_quant_config
td
.
layer
.
get_fused_moe_quant_config
=
get_fused_moe_quant_config
td
.
layer
.
quant_method
=
td
.
layer
td
.
layer
.
quant_method
=
td
.
layer
moe_config
=
FusedMoEConfig
(
num_experts
=
e
,
experts_per_token
=
topk
,
hidden_dim
=
k
,
intermediate_size_per_partition
=
n
,
num_local_experts
=
e
,
activation
=
activation
,
device
=
"cuda"
,
moe_parallel_config
=
FusedMoEParallelConfig
.
make_no_parallel
(),
in_dtype
=
torch
.
bfloat16
,
is_act_and_mul
=
is_act_and_mul
,
routing_method
=
RoutingMethodType
.
TopK
,
)
kernel
=
mk
.
FusedMoEModularKernel
(
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
quant_config
.
is_block_quantized
defer_input_quant
=
FlashInferExperts
.
expects_unquantized_inputs
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
)
),
),
FlashInferExperts
(
FlashInferExperts
(
out_dtype
=
td
.
layer
.
orig_dtype
,
moe_config
=
moe_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
ep_rank
=
td
.
layer
.
moe_parallel_config
.
ep_rank
,
ep_size
=
td
.
layer
.
moe_parallel_config
.
ep_size
,
tp_rank
=
td
.
layer
.
moe_parallel_config
.
tp_rank
,
tp_size
=
td
.
layer
.
moe_parallel_config
.
tp_size
,
use_dp
=
False
,
use_deepseek_fp8_block_scale
=
False
,
),
),
)
)
...
...
tests/kernels/moe/test_flashinfer_moe.py
View file @
42135d68
...
@@ -13,14 +13,19 @@ from tests.kernels.utils import torch_moe
...
@@ -13,14 +13,19 @@ from tests.kernels.utils import torch_moe
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
RoutingMethodType
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
FlashInferExperts
,
FlashInferExperts
,
is_valid_flashinfer_cutlass_fused_moe
,
is_valid_flashinfer_cutlass_fused_moe
,
)
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize
import
(
create_flashinfer_prepare_finalize
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
...
@@ -86,9 +91,28 @@ def test_flashinfer_fp4_moe_no_graph(
...
@@ -86,9 +91,28 @@ def test_flashinfer_fp4_moe_no_graph(
assert
is_valid_flashinfer_cutlass_fused_moe
(
a
,
w1_q
,
w2_q
)
assert
is_valid_flashinfer_cutlass_fused_moe
(
a
,
w1_q
,
w2_q
)
moe_config
=
FusedMoEConfig
(
num_experts
=
e
,
experts_per_token
=
topk
,
hidden_dim
=
k
,
intermediate_size_per_partition
=
n
,
num_local_experts
=
e
,
activation
=
activation
,
device
=
"cuda"
,
moe_parallel_config
=
FusedMoEParallelConfig
.
make_no_parallel
(),
in_dtype
=
dtype
,
is_act_and_mul
=
is_gated_act
,
routing_method
=
RoutingMethodType
.
TopK
,
)
flashinfer_experts
=
FusedMoEModularKernel
(
flashinfer_experts
=
FusedMoEModularKernel
(
create_flashinfer_prepare_finalize
(
use_dp
=
False
,
use_nvfp4
=
True
),
MoEPrepareAndFinalizeNoEP
(
FlashInferExperts
(
out_dtype
=
dtype
,
quant_config
=
quant_config
),
defer_input_quant
=
FlashInferExperts
.
expects_unquantized_inputs
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
)
),
FlashInferExperts
(
moe_config
=
moe_config
,
quant_config
=
quant_config
),
)
)
fi_activation
=
{
"silu_and_mul"
:
"silu"
,
"relu2"
:
"relu2_no_mul"
}[
activation
]
fi_activation
=
{
"silu_and_mul"
:
"silu"
,
"relu2"
:
"relu2_no_mul"
}[
activation
]
...
...
tests/kernels/moe/test_modular_oai_triton_moe.py
View file @
42135d68
...
@@ -36,6 +36,8 @@ from vllm.model_executor.layers.utils import shuffle_weight
...
@@ -36,6 +36,8 @@ from vllm.model_executor.layers.utils import shuffle_weight
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
from
.utils
import
make_dummy_moe_config
MNK
=
[
MNK
=
[
(
1
,
512
,
384
),
(
1
,
512
,
384
),
(
1
,
2880
,
2880
),
(
1
,
2880
,
2880
),
...
@@ -174,9 +176,9 @@ def oai_triton_moe_impl(
...
@@ -174,9 +176,9 @@ def oai_triton_moe_impl(
)
)
if
unfused
:
if
unfused
:
fused_experts
=
UnfusedOAITritonExperts
(
quant_config
)
fused_experts
=
UnfusedOAITritonExperts
(
make_dummy_moe_config
(),
quant_config
)
else
:
else
:
fused_experts
=
OAITritonExperts
(
quant_config
)
fused_experts
=
OAITritonExperts
(
make_dummy_moe_config
(),
quant_config
)
mk
=
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
fused_experts
)
mk
=
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
fused_experts
)
...
...
tests/kernels/moe/test_moe.py
View file @
42135d68
...
@@ -18,7 +18,7 @@ from transformers import MixtralConfig
...
@@ -18,7 +18,7 @@ from transformers import MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
import
vllm.model_executor.layers.fused_moe
# noqa
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.moe.utils
import
fused_moe
from
tests.kernels.moe.utils
import
fused_moe
,
make_dummy_moe_config
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_experts
,
torch_moe
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_experts
,
torch_moe
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
...
@@ -332,7 +332,7 @@ def test_fused_moe(
...
@@ -332,7 +332,7 @@ def test_fused_moe(
#
#
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
m_fused_moe_fn
=
modular_triton_fused_moe
(
quant_config
)
m_fused_moe_fn
=
modular_triton_fused_moe
(
make_dummy_moe_config
(),
quant_config
)
def
m_fused_moe
(
def
m_fused_moe
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
...
@@ -437,7 +437,7 @@ def test_naive_block_assignment_moe(
...
@@ -437,7 +437,7 @@ def test_naive_block_assignment_moe(
#
#
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
m_fused_moe_fn
=
modular_triton_fused_moe
(
quant_config
)
m_fused_moe_fn
=
modular_triton_fused_moe
(
make_dummy_moe_config
(),
quant_config
)
def
m_fused_moe
(
def
m_fused_moe
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
...
...
tests/kernels/moe/test_nvfp4_moe.py
View file @
42135d68
...
@@ -4,7 +4,7 @@ import pytest
...
@@ -4,7 +4,7 @@ import pytest
import
torch
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
tests.kernels.moe.utils
import
make_test_weights
from
tests.kernels.moe.utils
import
make_dummy_moe_config
,
make_test_weights
from
tests.kernels.quantization.nvfp4_utils
import
(
from
tests.kernels.quantization.nvfp4_utils
import
(
FLOAT4_E2M1_MAX
,
FLOAT4_E2M1_MAX
,
FLOAT8_E4M3_MAX
,
FLOAT8_E4M3_MAX
,
...
@@ -92,8 +92,7 @@ def test_cutlass_fp4_moe_no_graph(
...
@@ -92,8 +92,7 @@ def test_cutlass_fp4_moe_no_graph(
kernel
=
mk
.
FusedMoEModularKernel
(
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
True
),
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
True
),
CutlassExpertsFp4
(
CutlassExpertsFp4
(
out_dtype
=
dtype
,
moe_config
=
make_dummy_moe_config
(),
max_experts_per_worker
=
e
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
),
),
)
)
...
...
tests/kernels/moe/test_pplx_cutlass_moe.py
View file @
42135d68
...
@@ -9,12 +9,18 @@ from tests.kernels.utils import torch_experts
...
@@ -9,12 +9,18 @@ from tests.kernels.utils import torch_experts
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.config
import
fp8_w8a8_moe_quant_config
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
RoutingMethodType
,
fp8_w8a8_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassBatchedExpertsFp8
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassBatchedExpertsFp8
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.worker.workspace
import
init_workspace_manager
from
...utils
import
multi_gpu_test
from
...utils
import
multi_gpu_test
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
...
@@ -79,6 +85,8 @@ def pplx_cutlass_moe(
...
@@ -79,6 +85,8 @@ def pplx_cutlass_moe(
PplxPrepareAndFinalize
,
PplxPrepareAndFinalize
,
)
)
init_workspace_manager
(
torch
.
cuda
.
current_device
())
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
num_tokens
,
hidden_dim
=
a
.
shape
num_tokens
,
hidden_dim
=
a
.
shape
...
@@ -132,28 +140,23 @@ def pplx_cutlass_moe(
...
@@ -132,28 +140,23 @@ def pplx_cutlass_moe(
num_dispatchers
=
num_dispatchers
,
num_dispatchers
=
num_dispatchers
,
)
)
ab_strides1
=
torch
.
full
(
def
make_moe_config
()
->
FusedMoEConfig
:
(
num_local_experts
,),
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
return
FusedMoEConfig
(
)
num_experts
=
num_experts
,
ab_strides2
=
torch
.
full
(
experts_per_token
=
topk
,
(
num_local_experts
,),
intermediate_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
hidden_dim
=
hidden_dim
,
)
intermediate_size_per_partition
=
intermediate_dim
,
c_strides1
=
torch
.
full
(
num_local_experts
=
num_local_experts
,
(
num_local_experts
,),
2
*
intermediate_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
moe_parallel_config
=
FusedMoEParallelConfig
.
make_no_parallel
(),
)
activation
=
"silu"
,
c_strides2
=
torch
.
full
(
in_dtype
=
torch
.
bfloat16
,
(
num_local_experts
,),
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
device
=
"cuda"
,
)
routing_method
=
RoutingMethodType
.
Llama4
,
)
experts
=
CutlassBatchedExpertsFp8
(
experts
=
CutlassBatchedExpertsFp8
(
num_local_experts
,
moe_config
=
make_moe_config
(),
num_dispatchers
,
quant_config
=
fp8_w8a8_moe_quant_config
(
out_dtype
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
fp8_w8a8_moe_quant_config
(
per_act_token_quant
=
per_act_token
,
per_act_token_quant
=
per_act_token
,
per_out_ch_quant
=
per_out_ch
,
per_out_ch_quant
=
per_out_ch
,
w1_scale
=
chunk_by_rank
(
w1_scale
,
rank
,
world_size
),
w1_scale
=
chunk_by_rank
(
w1_scale
,
rank
,
world_size
),
...
@@ -162,6 +165,8 @@ def pplx_cutlass_moe(
...
@@ -162,6 +165,8 @@ def pplx_cutlass_moe(
if
per_act_token
if
per_act_token
else
a1_scale
[
rank
],
else
a1_scale
[
rank
],
),
),
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
)
)
fused_cutlass_experts
=
FusedMoEModularKernel
(
fused_cutlass_experts
=
FusedMoEModularKernel
(
...
...
tests/kernels/moe/test_pplx_moe.py
View file @
42135d68
...
@@ -29,6 +29,7 @@ except ImportError:
...
@@ -29,6 +29,7 @@ except ImportError:
from
tests.kernels.moe.modular_kernel_tools.parallel_utils
import
_set_vllm_config
from
tests.kernels.moe.modular_kernel_tools.parallel_utils
import
_set_vllm_config
from
tests.kernels.moe.utils
import
(
from
tests.kernels.moe.utils
import
(
make_dummy_moe_config
,
make_shared_experts
,
make_shared_experts
,
make_test_weights
,
make_test_weights
,
naive_batched_moe
,
naive_batched_moe
,
...
@@ -584,6 +585,7 @@ def pplx_moe(
...
@@ -584,6 +585,7 @@ def pplx_moe(
max_num_tokens
=
max_num_tokens
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
quant_config
,
quant_config
=
quant_config
,
moe_config
=
make_dummy_moe_config
(),
)
)
fused_experts
=
FusedMoEModularKernel
(
fused_experts
=
FusedMoEModularKernel
(
...
...
tests/kernels/moe/test_routing.py
View file @
42135d68
...
@@ -6,7 +6,6 @@ import pytest
...
@@ -6,7 +6,6 @@ import pytest
import
torch
import
torch
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.model_executor.layers.fused_moe.config
import
RoutingMethodType
from
vllm.model_executor.layers.fused_moe.router.router_factory
import
(
from
vllm.model_executor.layers.fused_moe.router.router_factory
import
(
create_fused_moe_router
,
create_fused_moe_router
,
)
)
...
@@ -385,17 +384,11 @@ def test_grouped_topk(
...
@@ -385,17 +384,11 @@ def test_grouped_topk(
global_num_experts
,
global_num_experts
,
)
)
routing_method_type
=
None
if
scoring_func
==
"llama4"
:
routing_method_type
=
RoutingMethodType
.
Llama4
scoring_func
=
"sigmoid"
router
=
create_fused_moe_router
(
router
=
create_fused_moe_router
(
use_grouped_topk
=
True
,
use_grouped_topk
=
True
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
routing_method_type
=
routing_method_type
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
routed_scaling_factor
=
routed_scaling_factor
,
top_k
=
top_k
,
top_k
=
top_k
,
...
...
tests/kernels/moe/test_triton_moe_no_act_mul.py
View file @
42135d68
...
@@ -10,6 +10,7 @@ equals N (not N // 2 like gated activations).
...
@@ -10,6 +10,7 @@ equals N (not N // 2 like gated activations).
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.moe.utils
import
make_dummy_moe_config
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
FUSED_MOE_UNQUANTIZED_CONFIG
,
)
)
...
@@ -78,7 +79,10 @@ def test_triton_experts_no_mul_activation(
...
@@ -78,7 +79,10 @@ def test_triton_experts_no_mul_activation(
m
,
n
,
k
,
NUM_EXPERTS
,
topk
m
,
n
,
k
,
NUM_EXPERTS
,
topk
)
)
experts
=
TritonExperts
(
FUSED_MOE_UNQUANTIZED_CONFIG
)
experts
=
TritonExperts
(
moe_config
=
make_dummy_moe_config
(),
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
,
)
ws1_shape
,
ws2_shape
,
out_shape
=
experts
.
workspace_shapes
(
ws1_shape
,
ws2_shape
,
out_shape
=
experts
.
workspace_shapes
(
M
=
m
,
M
=
m
,
...
@@ -151,7 +155,10 @@ def test_workspace_shapes_no_mul_vs_gated():
...
@@ -151,7 +155,10 @@ def test_workspace_shapes_no_mul_vs_gated():
M
,
N
,
K
,
topk
=
64
,
256
,
128
,
2
M
,
N
,
K
,
topk
=
64
,
256
,
128
,
2
experts
=
TritonExperts
(
FUSED_MOE_UNQUANTIZED_CONFIG
)
experts
=
TritonExperts
(
moe_config
=
make_dummy_moe_config
(),
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
,
)
ws1_no_mul
,
_
,
out_no_mul
=
experts
.
workspace_shapes
(
ws1_no_mul
,
_
,
out_no_mul
=
experts
.
workspace_shapes
(
M
,
N
,
K
,
topk
,
8
,
8
,
None
,
SILU_NO_MUL
M
,
N
,
K
,
topk
,
8
,
8
,
None
,
SILU_NO_MUL
...
@@ -187,7 +194,10 @@ def test_adjust_n_for_activation():
...
@@ -187,7 +194,10 @@ def test_adjust_n_for_activation():
"""Test the adjust_N_for_activation method."""
"""Test the adjust_N_for_activation method."""
from
vllm.model_executor.layers.fused_moe.fused_moe
import
TritonExperts
from
vllm.model_executor.layers.fused_moe.fused_moe
import
TritonExperts
experts
=
TritonExperts
(
FUSED_MOE_UNQUANTIZED_CONFIG
)
experts
=
TritonExperts
(
moe_config
=
make_dummy_moe_config
(),
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
,
)
N
=
256
N
=
256
...
...
tests/kernels/moe/utils.py
View file @
42135d68
...
@@ -8,7 +8,12 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
...
@@ -8,7 +8,12 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
from
tests.kernels.quantization.nvfp4_utils
import
FLOAT4_E2M1_MAX
,
FLOAT8_E4M3_MAX
from
tests.kernels.quantization.nvfp4_utils
import
FLOAT4_E2M1_MAX
,
FLOAT8_E4M3_MAX
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
)
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
BatchedPrepareAndFinalize
,
BatchedPrepareAndFinalize
,
BatchedTritonExperts
,
BatchedTritonExperts
,
...
@@ -20,6 +25,34 @@ from vllm.utils.deep_gemm import per_block_cast_to_fp8
...
@@ -20,6 +25,34 @@ from vllm.utils.deep_gemm import per_block_cast_to_fp8
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.math_utils
import
round_up
def
make_dummy_moe_config
(
num_experts
:
int
=
1
,
experts_per_token
:
int
=
1
,
hidden_dim
:
int
=
1
,
intermediate_size_per_partition
:
int
=
1
,
in_dtype
:
torch
.
dtype
=
torch
.
bfloat16
,
)
->
FusedMoEConfig
:
"""
This is a dummy config for the mk constructor interface
as most kernels like DeepGEMM, CUTLASSFp4, Triton, MARLIN
do not actually use this config.
CUTLASSFp8 needs to set some params for workshapes.
"""
return
FusedMoEConfig
(
num_experts
=
num_experts
,
experts_per_token
=
experts_per_token
,
hidden_dim
=
hidden_dim
,
intermediate_size_per_partition
=
intermediate_size_per_partition
,
num_local_experts
=
num_experts
,
moe_parallel_config
=
FusedMoEParallelConfig
.
make_no_parallel
(),
activation
=
"silu"
,
in_dtype
=
in_dtype
,
device
=
"cuda"
,
routing_method
=
RoutingMethodType
.
TopK
,
)
def
triton_moe
(
def
triton_moe
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
@@ -81,6 +114,7 @@ def batched_moe(
...
@@ -81,6 +114,7 @@ def batched_moe(
max_num_tokens
=
max_num_tokens
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
1
,
num_dispatchers
=
1
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
moe_config
=
make_dummy_moe_config
(),
),
),
)
)
...
@@ -121,6 +155,7 @@ def naive_batched_moe(
...
@@ -121,6 +155,7 @@ def naive_batched_moe(
max_num_tokens
=
max_num_tokens
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
1
,
num_dispatchers
=
1
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
moe_config
=
make_dummy_moe_config
(),
),
),
)
)
...
...
tools/vllm-rocm/pin_rocm_dependencies.py
View file @
42135d68
...
@@ -11,10 +11,11 @@ This ensures that 'pip install vllm' automatically installs the correct custom w
...
@@ -11,10 +11,11 @@ This ensures that 'pip install vllm' automatically installs the correct custom w
instead of allowing pip to download different versions from PyPI.
instead of allowing pip to download different versions from PyPI.
"""
"""
import
re
import
sys
import
sys
from
pathlib
import
Path
from
pathlib
import
Path
import
regex
as
re
def
extract_version_from_wheel
(
wheel_name
:
str
)
->
str
:
def
extract_version_from_wheel
(
wheel_name
:
str
)
->
str
:
"""
"""
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment