Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
09e4576f
Unverified
Commit
09e4576f
authored
Mar 17, 2026
by
Michael Goin
Committed by
GitHub
Mar 17, 2026
Browse files
[Kernel] Add non-gated support for NVFP4 CUTLASS MoE (#37320)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent
3ed7b1e6
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing 8 changed files with 53 additions and 26 deletions
+53
-26
csrc/ops.h
csrc/ops.h
+2
-1
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
+15
-13
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
+5
-3
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-2
tests/evals/gsm8k/configs/moe-refactor/Nemotron-Nano-30B-NvFp4-ModelOpt-vllm-cutlass.yaml
...factor/Nemotron-Nano-30B-NvFp4-ModelOpt-vllm-cutlass.yaml
+5
-0
tests/evals/gsm8k/configs/moe-refactor/config-b200.txt
tests/evals/gsm8k/configs/moe-refactor/config-b200.txt
+1
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+4
-0
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+19
-7
No files found.
csrc/ops.h
View file @
09e4576f
...
...
@@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data(
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
const
bool
is_gated
);
void
get_cutlass_moe_mm_problem_sizes_from_expert_offsets
(
const
torch
::
Tensor
&
expert_first_token_offset
,
...
...
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
View file @
09e4576f
...
...
@@ -17,8 +17,11 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int32_t
*
problem_sizes2
,
int32_t
*
atomic_buffer
,
const
int
topk_length
,
const
int
n
,
const
int
k
)
{
const
int
k
,
const
bool
is_gated
)
{
int
expert_id
=
blockIdx
.
x
;
// For gated activations (gate + up), first GEMM output is 2*n.
// For non-gated activations (up only), first GEMM output is n.
int
const
n1
=
is_gated
?
2
*
n
:
n
;
int
occurrences
=
0
;
for
(
int
i
=
threadIdx
.
x
;
i
<
topk_length
;
i
+=
THREADS_PER_EXPERT
)
{
...
...
@@ -31,13 +34,13 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int
final_occurrences
=
atomic_buffer
[
expert_id
];
if
constexpr
(
!
SWAP_AB
)
{
problem_sizes1
[
expert_id
*
3
]
=
final_occurrences
;
problem_sizes1
[
expert_id
*
3
+
1
]
=
2
*
n
;
problem_sizes1
[
expert_id
*
3
+
1
]
=
n
1
;
problem_sizes1
[
expert_id
*
3
+
2
]
=
k
;
problem_sizes2
[
expert_id
*
3
]
=
final_occurrences
;
problem_sizes2
[
expert_id
*
3
+
1
]
=
k
;
problem_sizes2
[
expert_id
*
3
+
2
]
=
n
;
}
else
{
problem_sizes1
[
expert_id
*
3
]
=
2
*
n
;
problem_sizes1
[
expert_id
*
3
]
=
n
1
;
problem_sizes1
[
expert_id
*
3
+
1
]
=
final_occurrences
;
problem_sizes1
[
expert_id
*
3
+
2
]
=
k
;
problem_sizes2
[
expert_id
*
3
]
=
k
;
...
...
@@ -107,13 +110,11 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
}
namespace
{
inline
void
launch_compute_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
atomic_buffer
,
int64_t
num_experts
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
,
const
bool
swap_ab
)
{
inline
void
launch_compute_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
atomic_buffer
,
int64_t
num_experts
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
,
const
bool
swap_ab
,
const
bool
is_gated
)
{
int
num_threads
=
min
(
THREADS_PER_EXPERT
,
topk_ids
.
numel
());
auto
const
*
topk_ptr
=
topk_ids
.
data_ptr
<
int32_t
>
();
...
...
@@ -125,7 +126,7 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
compute_problem_sizes
<
SwapAB
><<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
topk_ptr
,
ps1_ptr
,
ps2_ptr
,
atomic_ptr
,
static_cast
<
int
>
(
topk_ids
.
numel
()),
static_cast
<
int
>
(
n
),
static_cast
<
int
>
(
k
));
static_cast
<
int
>
(
k
)
,
is_gated
);
});
}
}
// namespace
...
...
@@ -222,7 +223,8 @@ void get_cutlass_moe_mm_data_caller(
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
const
bool
is_gated
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
topk_ids
.
device
().
index
());
auto
options_int32
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
topk_ids
.
device
());
...
...
@@ -236,7 +238,7 @@ void get_cutlass_moe_mm_data_caller(
launch_compute_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
atomic_buffer
,
num_experts
,
n
,
k
,
stream
,
may_swap_ab
);
may_swap_ab
,
is_gated
);
if
(
blockscale_offsets
.
has_value
())
{
// fp4 path
...
...
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
View file @
09e4576f
...
...
@@ -75,7 +75,8 @@ void get_cutlass_moe_mm_data_caller(
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
const
bool
is_gated
);
void
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller
(
const
torch
::
Tensor
&
expert_first_token_offset
,
...
...
@@ -278,7 +279,8 @@ void get_cutlass_moe_mm_data(
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
const
bool
is_gated
)
{
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t
version_num
=
get_sm_version_num
();
...
...
@@ -288,7 +290,7 @@ void get_cutlass_moe_mm_data(
get_cutlass_moe_mm_data_caller
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
,
blockscale_offsets
);
blockscale_offsets
,
is_gated
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
...
...
csrc/torch_bindings.cpp
View file @
09e4576f
...
...
@@ -489,8 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k, Tensor? blockscale_offsets
) ->
"
"()"
);
" int n, int k, Tensor? blockscale_offsets
,
"
"
bool is_gated) ->
()"
);
ops
.
impl
(
"get_cutlass_moe_mm_data"
,
torch
::
kCUDA
,
&
get_cutlass_moe_mm_data
);
// compute per-expert problem sizes from expert_first_token_offset
...
...
tests/evals/gsm8k/configs/moe-refactor/Nemotron-Nano-30B-NvFp4-ModelOpt-vllm-cutlass.yaml
0 → 100644
View file @
09e4576f
model_name: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
accuracy_threshold: 0.29
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=cutlass"
tests/evals/gsm8k/configs/moe-refactor/config-b200.txt
View file @
09e4576f
...
...
@@ -15,3 +15,4 @@ Mixtral-8x7B-BF16-fi-cutlass.yaml
Mixtral-8x7B-BF16-triton.yaml
Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml
Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml
Nemotron-Nano-30B-NvFp4-ModelOpt-vllm-cutlass.yaml
vllm/_custom_ops.py
View file @
09e4576f
...
...
@@ -989,6 +989,7 @@ def get_cutlass_moe_mm_data(
n
:
int
,
k
:
int
,
blockscale_offsets
:
torch
.
Tensor
|
None
=
None
,
is_gated
:
bool
=
True
,
):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
...
...
@@ -1012,6 +1013,8 @@ def get_cutlass_moe_mm_data(
its computation. The number of block scale rows
computed with expert E is blockscale_offsets[E + 1] -
blockscale_offsets[E]
- is_gated: Whether the activation is gated (gate + up). When True, the
first GEMM N dimension is 2*n; when False, it is n.
"""
return
torch
.
ops
.
_C
.
get_cutlass_moe_mm_data
(
topk_ids
,
...
...
@@ -1024,6 +1027,7 @@ def get_cutlass_moe_mm_data(
n
,
k
,
blockscale_offsets
,
is_gated
,
)
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
09e4576f
...
...
@@ -507,11 +507,12 @@ def run_cutlass_moe_fp4(
# Gemm 1
a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32)
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
w1 (not an argument to cutlass_moe_fp4): [e, w1_n, k]
w1_fp4: [e, w1_n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
where w1_n = 2*n for gated activations (gate+up), n for non-gated (up only).
(Note: `n` is the up projection output dim, `k` is the input dim in
full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
w1_blockscale: [e, w1_n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4)
# Gemm 2
...
...
@@ -528,6 +529,11 @@ def run_cutlass_moe_fp4(
assumes that topk < k < n to satisfy - up/down projection expectations.
"""
is_gated
=
activation
.
is_gated
# For gated activations (e.g. SiLU), w1 output is 2*n (gate + up).
# For non-gated activations (e.g. SiLU_NO_MUL), w1 output is n (up only).
w1_n
=
n
*
2
if
is_gated
else
n
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
w1_fp4
.
dtype
==
torch
.
uint8
,
"weight 1 must be uint8"
assert
w2_fp4
.
dtype
==
torch
.
uint8
,
"weight 2 must be uint8"
...
...
@@ -538,7 +544,7 @@ def run_cutlass_moe_fp4(
and
w2_blockscale
.
ndim
==
3
),
"All Weights must be of rank 3 for cutlass_moe_fp4"
m_a
,
k_a
=
a
.
shape
e_w1
,
nx2_w1
,
half_k_w1
=
w1_fp4
.
shape
e_w1
,
w1_n_actual
,
half_k_w1
=
w1_fp4
.
shape
e_w2
,
k_w2
,
half_n_w2
=
w2_fp4
.
shape
assert
e_w1
==
e_w2
and
e_w1
==
e
,
(
...
...
@@ -548,7 +554,7 @@ def run_cutlass_moe_fp4(
assert
k_a
==
half_k_w1
*
2
and
k
==
k_w2
,
(
"Hidden size mismatch between a, w1 and w2"
)
assert
nx2_w1
==
n
*
2
and
half_n_w2
*
2
==
n
,
"mismatch in expected `n`"
assert
w1_n_actual
==
w1_n
and
half_n_w2
*
2
==
n
,
"mismatch in expected `n`"
assert
m
==
m_a
,
"input shape mismatch"
assert
2
*
half_k_w1
==
k_w2
,
"Hidden size mismatch w2 and w1"
assert
a
.
dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid input dtype"
...
...
@@ -589,6 +595,7 @@ def run_cutlass_moe_fp4(
n
,
k
,
blockscale_offsets
,
is_gated
=
is_gated
,
)
a
=
ops
.
shuffle_rows
(
a
,
a_map
)
...
...
@@ -599,7 +606,7 @@ def run_cutlass_moe_fp4(
blockscale_offsets
,
num_topk
,
)
c1
=
_resize_cache
(
workspace13
,
(
m
*
topk
,
n
*
2
))
c1
=
_resize_cache
(
workspace13
,
(
m
*
topk
,
w1_n
))
c2
=
_resize_cache
(
workspace2
,
(
m
*
topk
,
n
))
c3
=
_resize_cache
(
workspace13
,
(
m
*
topk
,
k
))
ops
.
cutlass_fp4_moe_mm
(
...
...
@@ -681,7 +688,7 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
return
Fals
e
return
Tru
e
@
staticmethod
def
_supports_quant_scheme
(
...
...
@@ -695,11 +702,16 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
# SILU uses a fused silu+mul+fp4_quant kernel path.
# Other gated activations use the generic apply_moe_activation()
# fallback + separate fp4 quantization in run_cutlass_moe_fp4().
# Non-gated activations (_NO_MUL) are also supported for models
# like Nemotron-Nano that don't use gated MLP.
return
activation
in
[
MoEActivation
.
SILU
,
MoEActivation
.
GELU
,
MoEActivation
.
SWIGLUOAI
,
MoEActivation
.
SWIGLUSTEP
,
MoEActivation
.
SILU_NO_MUL
,
MoEActivation
.
GELU_NO_MUL
,
MoEActivation
.
RELU2_NO_MUL
,
]
@
staticmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment