Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9fb2d220
Unverified
Commit
9fb2d220
authored
Jul 17, 2025
by
ElizaWszola
Committed by
GitHub
Jul 17, 2025
Browse files
[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE (#20762)
Signed-off-by:
ElizaWszola
<
ewszola@redhat.com
>
parent
2d6a3820
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
174 additions
and
38 deletions
+174
-38
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+34
-1
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+42
-11
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+12
-2
tests/kernels/moe/test_pplx_cutlass_moe.py
tests/kernels/moe/test_pplx_cutlass_moe.py
+22
-0
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+39
-23
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+25
-1
No files found.
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
View file @
9fb2d220
...
...
@@ -80,6 +80,11 @@ def bench_run(
a
,
score
,
topk
,
renormalize
=
False
)
ab_strides1
=
torch
.
full
((
num_experts
,),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
num_experts
,),
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
num_experts
,),
2
*
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
num_experts
,),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
def
run_triton_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
...
@@ -111,6 +116,10 @@ def bench_run(
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
per_act_token
:
bool
,
...
...
@@ -125,6 +134,10 @@ def bench_run(
topk_ids
,
w1_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
per_act_token
,
a1_scale
=
None
,
)
...
...
@@ -136,6 +149,10 @@ def bench_run(
w2_q
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
):
...
...
@@ -150,6 +167,10 @@ def bench_run(
topk_ids
,
w1_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
per_act_token
,
a1_scale
=
None
,
)
...
...
@@ -194,6 +215,10 @@ def bench_run(
w2_q
,
w1_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
topk_weights
,
topk_ids
,
)
...
...
@@ -231,6 +256,10 @@ def bench_run(
"w1_scale"
:
w1_scale
,
"w2_scale"
:
w2_scale
,
"per_act_token"
:
per_act_token
,
"ab_strides1"
:
ab_strides1
,
"ab_strides2"
:
ab_strides2
,
"c_strides1"
:
c_strides1
,
"c_strides2"
:
c_strides2
,
# cuda graph params
"cutlass_graph"
:
cutlass_graph
,
"triton_graph"
:
triton_graph
,
...
...
@@ -289,6 +318,10 @@ def bench_run(
w2_q
,
w1_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
topk_weights
,
topk_ids
,
per_act_token
,
...
...
@@ -297,7 +330,7 @@ def bench_run(
results
.
append
(
benchmark
.
Timer
(
stmt
=
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)"
,
# noqa: E501
stmt
=
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
ab_strides1, ab_strides2, c_strides1, c_strides2,
topk_weights, topk_ids, per_act_token, num_runs)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
...
...
csrc/moe/moe_permute_unpermute_op.cu
View file @
9fb2d220
...
...
@@ -160,6 +160,30 @@ __global__ void shuffleInputRowsKernel(const T* input,
}
}
template
<
typename
T
>
__global__
void
shuffleInputRowsKernelSlow
(
const
T
*
input
,
const
int32_t
*
dst2src_map
,
T
*
output
,
int64_t
num_src_rows
,
int64_t
num_dst_rows
,
int64_t
num_cols
)
{
int64_t
dest_row_idx
=
blockIdx
.
x
;
int64_t
const
source_row_idx
=
dst2src_map
[
dest_row_idx
];
if
(
blockIdx
.
x
<
num_dst_rows
)
{
// Duplicate and permute rows
auto
const
*
source_row_ptr
=
input
+
source_row_idx
*
num_cols
;
auto
*
dest_row_ptr
=
output
+
dest_row_idx
*
num_cols
;
int64_t
const
start_offset
=
threadIdx
.
x
;
int64_t
const
stride
=
blockDim
.
x
;
for
(
int
elem_index
=
start_offset
;
elem_index
<
num_cols
;
elem_index
+=
stride
)
{
dest_row_ptr
[
elem_index
]
=
source_row_ptr
[
elem_index
];
}
}
}
void
shuffle_rows
(
const
torch
::
Tensor
&
input_tensor
,
const
torch
::
Tensor
&
dst2src_map
,
torch
::
Tensor
&
output_tensor
)
{
...
...
@@ -173,10 +197,16 @@ void shuffle_rows(const torch::Tensor& input_tensor,
int64_t
const
num_src_rows
=
input_tensor
.
size
(
0
);
int64_t
const
num_cols
=
input_tensor
.
size
(
1
);
TORCH_CHECK
(
!
(
num_cols
%
(
128
/
sizeof
(
input_tensor
.
scalar_type
())
/
8
)),
"num_cols must be divisible by 128 / "
"sizeof(input_tensor.scalar_type()) / 8"
);
if
(
num_cols
%
(
128
/
sizeof
(
input_tensor
.
scalar_type
())
/
8
))
{
// use slow kernel if num_cols can't be aligned to 128 bits
MOE_DISPATCH
(
input_tensor
.
scalar_type
(),
[
&
]
{
shuffleInputRowsKernelSlow
<
scalar_t
><<<
blocks
,
threads
,
0
,
stream
>>>
(
reinterpret_cast
<
scalar_t
*>
(
input_tensor
.
data_ptr
()),
dst2src_map
.
data_ptr
<
int32_t
>
(),
reinterpret_cast
<
scalar_t
*>
(
output_tensor
.
data_ptr
()),
num_src_rows
,
num_dest_rows
,
num_cols
);
});
}
else
{
MOE_DISPATCH
(
input_tensor
.
scalar_type
(),
[
&
]
{
shuffleInputRowsKernel
<
scalar_t
><<<
blocks
,
threads
,
0
,
stream
>>>
(
reinterpret_cast
<
scalar_t
*>
(
input_tensor
.
data_ptr
()),
...
...
@@ -184,6 +214,7 @@ void shuffle_rows(const torch::Tensor& input_tensor,
reinterpret_cast
<
scalar_t
*>
(
output_tensor
.
data_ptr
()),
num_src_rows
,
num_dest_rows
,
num_cols
);
});
}
}
#else
...
...
tests/kernels/moe/test_cutlass_moe.py
View file @
9fb2d220
...
...
@@ -206,6 +206,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids'
:
topk_ids
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'c_strides2'
:
moe_tensors
.
c_strides2
,
'per_act_token'
:
per_act_token
,
'a1_scale'
:
None
#moe_tensors.a_scale
}
...
...
@@ -439,6 +443,11 @@ def test_run_cutlass_moe_fp8(
expert_map
[
start
:
end
]
=
list
(
range
(
num_local_experts
))
expert_map
=
torch
.
tensor
(
expert_map
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
ab_strides1
=
torch
.
full
((
e
,
),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
e
,
),
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
e
,
),
2
*
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
e
,
),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
activation
=
lambda
o
,
i
:
torch
.
ops
.
_C
.
silu_and_mul
(
o
,
i
)
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
mt
.
a
,
mt
.
a_scale
,
torch
.
float8_e4m3fn
,
...
...
@@ -447,8 +456,9 @@ def test_run_cutlass_moe_fp8(
func
=
lambda
output
:
run_cutlass_moe_fp8
(
output
,
a1q
,
mt
.
w1_q
,
mt
.
w2_q
,
topk_ids
,
activation
,
global_num_experts
,
expert_map
,
mt
.
w1_scale
,
mt
.
w2_scale
,
a1q_scale
,
None
,
workspace13
,
workspace2
,
None
,
mt
.
a
.
dtype
,
per_act_token
,
per_out_channel
,
False
)
a1q_scale
,
None
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
workspace13
,
workspace2
,
None
,
mt
.
a
.
dtype
,
per_act_token
,
per_out_channel
,
False
)
workspace13
.
random_
()
output_random_workspace
=
torch
.
empty
(
output_shape
,
...
...
tests/kernels/moe/test_pplx_cutlass_moe.py
View file @
9fb2d220
...
...
@@ -75,6 +75,7 @@ def pplx_cutlass_moe(
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
num_tokens
,
hidden_dim
=
a
.
shape
intermediate_dim
=
w2
.
shape
[
2
]
num_experts
=
w1
.
shape
[
0
]
block_size
=
hidden_dim
# TODO support more cases
device
=
pgi
.
device
...
...
@@ -123,10 +124,31 @@ def pplx_cutlass_moe(
num_local_experts
=
num_local_experts
,
num_dispatchers
=
num_dispatchers
)
ab_strides1
=
torch
.
full
((
num_local_experts
,
),
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
num_local_experts
,
),
intermediate_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
num_local_experts
,
),
2
*
intermediate_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
num_local_experts
,
),
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
experts
=
CutlassExpertsFp8
(
num_local_experts
,
out_dtype
,
per_act_token
,
per_out_ch
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
num_dispatchers
=
num_dispatchers
,
use_batched_format
=
True
)
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
9fb2d220
...
...
@@ -13,8 +13,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_perm
,
_fp8_quantize
,
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_quantize
,
_resize_cache
)
from
vllm.scalar_type
import
scalar_types
...
...
@@ -34,6 +33,10 @@ def run_cutlass_moe_fp8(
w2_scale
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
...
...
@@ -152,27 +155,11 @@ def run_cutlass_moe_fp8(
problem_sizes1
,
problem_sizes2
,
a_map
,
c_map
,
global_num_experts
,
N
,
K
)
a1q
=
_fp8_perm
(
a1q
,
a_map
)
a1q_scale
=
a1q_scale
[
a_map
]
if
per_act_token
else
a1q_scale
a1q
=
ops
.
shuffle_rows
(
a1q
,
a_map
)
a1q_scale
=
(
ops
.
shuffle_rows
(
a1q_scale
,
a_map
)
if
per_act_token
else
a1q_scale
)
expert_offsets
=
expert_offsets
[:
-
1
]
ab_strides1
=
torch
.
full
((
w1
.
size
(
0
),
),
K
,
device
=
device
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
w1
.
size
(
0
),
),
2
*
N
,
device
=
device
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
w1
.
size
(
0
),
),
N
,
device
=
device
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
w1
.
size
(
0
),
),
K
,
device
=
device
,
dtype
=
torch
.
int64
)
if
use_batched_format
:
c1
=
_resize_cache
(
workspace13
,
(
local_E
*
padded_M
,
N
*
2
))
c2
=
_resize_cache
(
workspace2
,
(
local_E
*
padded_M
,
N
))
...
...
@@ -209,7 +196,8 @@ def run_cutlass_moe_fp8(
else
:
# We can't do this inplace because output may point to the same tensor
# as c3.
output
.
copy_
(
c3
[
c_map
].
view
(
M
*
topk
,
K
),
non_blocking
=
True
)
output
.
copy_
(
ops
.
shuffle_rows
(
c3
,
c_map
).
view
(
M
*
topk
,
K
),
non_blocking
=
True
)
# TODO (bnell): split class batched vs. non-batched?
...
...
@@ -222,6 +210,10 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype
:
Optional
[
torch
.
dtype
],
per_act_token_quant
:
bool
,
per_out_ch_quant
:
bool
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
num_dispatchers
:
Optional
[
int
]
=
None
,
use_batched_format
:
bool
=
False
,
...
...
@@ -238,6 +230,10 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
self
.
max_experts_per_worker
=
max_experts_per_worker
self
.
num_dispatchers
=
num_dispatchers
self
.
out_dtype
=
out_dtype
self
.
ab_strides1
=
ab_strides1
self
.
ab_strides2
=
ab_strides2
self
.
c_strides1
=
c_strides1
self
.
c_strides2
=
c_strides2
self
.
use_batched_format
=
use_batched_format
@
property
...
...
@@ -316,7 +312,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
run_cutlass_moe_fp8
(
output
,
hidden_states
,
w1
,
w2
,
topk_ids
,
activation_callable
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
a1q_scale
,
a2_scale
,
workspace13
,
workspace2
,
expert_num_tokens
,
a2_scale
,
self
.
ab_strides1
,
self
.
ab_strides2
,
self
.
c_strides1
,
self
.
c_strides2
,
workspace13
,
workspace2
,
expert_num_tokens
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
in_dtype
,
self
.
per_act_token_quant
,
self
.
per_out_ch_quant
,
self
.
use_batched_format
)
...
...
@@ -330,6 +327,10 @@ def cutlass_moe_fp8(
topk_ids
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
per_act_token
:
Optional
[
bool
]
=
None
,
activation
:
str
=
"silu"
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -357,6 +358,17 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
...
...
@@ -389,6 +401,10 @@ def cutlass_moe_fp8(
out_dtype
=
a
.
dtype
,
per_act_token_quant
=
per_act_token
,
per_out_ch_quant
=
per_out_ch
,
ab_strides1
=
ab_strides1
,
ab_strides2
=
ab_strides2
,
c_strides1
=
c_strides1
,
c_strides2
=
c_strides2
,
use_batched_format
=
False
,
),
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
9fb2d220
...
...
@@ -859,6 +859,21 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
device
=
layer
.
w13_weight
.
device
# ab_strides1 and c_strides2 are the same
self
.
ab_strides1_c_strides2
=
torch
.
full
((
layer
.
local_num_experts
,
),
layer
.
hidden_size
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
ab_strides2
=
torch
.
full
((
layer
.
local_num_experts
,
),
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
c_strides1
=
torch
.
full
((
layer
.
local_num_experts
,
),
2
*
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
...
...
@@ -881,6 +896,10 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
moe
.
in_dtype
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
num_dispatchers
=
num_dispatchers
,
use_batched_format
=
use_batched_format
,
)
...
...
@@ -927,7 +946,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
)
per_act_token
=
(
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
)
...
...
@@ -948,6 +968,10 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment