sglang / Commits / d4a93841

Unverified commit d4a93841, authored Sep 02, 2025 by chenxj, committed by GitHub Sep 01, 2025.

[feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118)

Co-authored-by: yuhyao <827623970@qq.com>
parent 21e1bc47
Showing 11 changed files with 291 additions and 120 deletions.
python/sglang/srt/configs/model_config.py                        +2    -1
python/sglang/srt/layers/moe/cutlass_w4a8_moe.py                 +1    -9
python/sglang/srt/layers/moe/ep_moe/layer.py                     +0    -3
python/sglang/srt/layers/moe/fused_moe_triton/layer.py           +5    -2
python/sglang/srt/layers/quantization/w4afp8.py                  +30   -25
python/sglang/srt/models/deepseek_v2.py                          +5    -0
python/sglang/test/test_cutlass_w4a8_moe.py                      +24   -9
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh   +1    -1
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu      +206  -60
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh     +7    -6
sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py                     +10   -4
python/sglang/srt/configs/model_config.py

@@ -405,9 +405,10 @@ class ModelConfig:
         # compressed-tensors uses a "compression_config" key
         quant_cfg = getattr(self.hf_config, "compression_config", None)
         if quant_cfg is None:
-            # check if is modelopt model -- modelopt doesn't have corresponding field
+            # check if is modelopt or mixed-precision model -- Both of them don't have corresponding field
             # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
             # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+            # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
             is_local = os.path.exists(self.model_path)
             modelopt_quant_config = {"quant_method": "modelopt"}
             if not is_local:
...
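A minimal sketch of the lookup this comment describes (hypothetical helper, not the exact sglang code): for modelopt and mixed-precision checkpoints the quantization metadata lives in a standalone hf_quant_config.json next to the weights, so a local checkout can be probed directly.

import json
import os

def load_hf_quant_config(model_path: str):
    """Return the parsed hf_quant_config.json if the checkpoint ships one."""
    cfg_path = os.path.join(model_path, "hf_quant_config.json")
    if os.path.exists(cfg_path):  # local checkout
        with open(cfg_path) as f:
            return json.load(f)
    return None  # a remote repo would need a hub download before probing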
python/sglang/srt/layers/moe/cutlass_w4a8_moe.py

@@ -91,18 +91,10 @@ def cutlass_w4a8_moe(
     assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
     assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
     assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
-    assert (
-        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
-        and w1_scale.shape[2] == w1_q.shape[1] * 4
-    ), "W1 scale shape mismatch"
-    assert (
-        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
-        and w2_scale.shape[2] == w2_q.shape[1] * 4
-    ), "W2 scale shape mismatch"
     assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
     assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
     assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
     assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
     num_experts = w1_q.size(0)
     m = a.size(0)
...
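The removed asserts hard-coded the EP-mode scale layout, which assumes K is a multiple of 512 (group_size 128 with scales packed four at a time); the new tp shards break that (the kernel diff below adds a k == 256 path). A worked check of the old relation, using assumed EP-mode gemm-1 sizes for illustration:

E, N, K = 256, 4096, 7168              # assumed EP-mode sizes for group gemm 1
w1_q_shape = (E, N, K // 2)            # int4 weights, two packed per int8
w1_scale_shape = (E, K // 512, N * 4)  # group_size=128, scales packed in 4s
assert w1_scale_shape[1] == w1_q_shape[2] * 2 // 512
assert w1_scale_shape[2] == w1_q_shape[1] * 4
assert 256 % 512 != 0                  # a tp-mode K=256 shard fails the old check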
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -114,9 +114,6 @@ class EPMoE(FusedMoE):
             with_bias=with_bias,
         )

-        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
-        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
-
         self.intermediate_size = intermediate_size

         if isinstance(quant_config, Fp8Config):
...
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -175,6 +175,8 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
...
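The expert-id bookkeeping moves from EPMoE (removed above) into the FusedMoE base class, so the tp path gets it too. A numeric illustration with assumed DeepSeek-R1-like sizes:

num_experts, moe_ep_size = 256, 8
num_local_experts = num_experts // moe_ep_size           # 32 experts per EP rank
moe_ep_rank = 3
start_expert_id = moe_ep_rank * num_local_experts       # 96
end_expert_id = start_expert_id + num_local_experts - 1  # 127, inclusive
# With moe_ep_size == 1 (pure tp), every rank spans experts 0..255.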
@@ -593,8 +595,9 @@ class FusedMoE(torch.nn.Module):
             if (
                 "compressed" in self.quant_method.__class__.__name__.lower()
                 and param.data[expert_id] != 1
                 and (param.data[expert_id] - loaded_weight).abs() > 1e-5
+                or "w4afp8" in self.quant_config.get_name()
+                and (param.data[expert_id] != 1).any()
+                and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
             ):
                 raise ValueError(
                     "input_scales of w1 and w3 of a layer "
...
python/sglang/srt/layers/quantization/w4afp8.py

 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter

+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,

...

@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.managers.schedule_batch import global_server_args_dict

         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
-        elif isinstance(layer, EPMoE):
+        elif isinstance(layer, FusedMoE):
             return W4AFp8MoEMethod(self)
         return None
...
@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
         return []


+def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
+    """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+    s_shape = scales.shape
+    # Reshape to separate groups of 4
+    alignment = 4 if s_shape[2] % 4 == 0 else 1
+    scales_interleaved = scales.reshape(
+        s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment
+    )
+    # Permute dimensions to interleave
+    scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+    # Reshape back to original dimensions but with interleaved values
+    scales_interleaved = scales_interleaved.reshape(
+        s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment
+    )
+    return scales_interleaved.contiguous()
+
+
 class W4AFp8MoEMethod(FusedMoEMethodBase):

     def __init__(self, quant_config: W4AFp8Config):
         self.quant_config = quant_config
...
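A quick shape check of the relocated helper (now module-level, with an alignment fallback so unaligned tp shards also pass through it); sizes here are illustrative, and the import assumes the module layout above:

import torch
from sglang.srt.layers.quantization.w4afp8 import interleave_scales

scales = torch.randn(2, 3, 4, dtype=torch.bfloat16)      # [E=2, N=3, G=4] scale groups
assert interleave_scales(scales).shape == (2, 1, 12)     # [E, G/4, N*4], packs of 4

scales_tp = torch.randn(2, 3, 2, dtype=torch.bfloat16)   # G=2: not divisible by 4
assert interleave_scales(scales_tp).shape == (2, 2, 3)   # alignment falls back to 1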
@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         return

-    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
-        """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
-        s_shape = scales.shape
-        # Reshape to separate groups of 4
-        scales_interleaved = scales.reshape(
-            s_shape[0], s_shape[1], (s_shape[2] // 4), 4
-        )
-        # Permute dimensions to interleave
-        scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
-        # Reshape back to original dimensions but with interleaved values
-        scales_interleaved = scales_interleaved.reshape(
-            s_shape[0], s_shape[2] // 4, s_shape[1] * 4
-        )
-        return scales_interleaved.contiguous()
-
     def process_weights_after_loading(self, layer: Module) -> None:
         dtype = torch.bfloat16
         device = layer.w2_weight.device

         # Interleave w13_weight_scale (gate_up_proj)
         w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
-        w13_weight_scale = self._interleave_scales(w13_weight_scale)
+        w13_weight_scale = interleave_scales(w13_weight_scale)
         layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)

         # Interleave w2_weight_scale (down_proj)
         w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
-        w2_weight_scale = self._interleave_scales(w2_weight_scale)
+        w2_weight_scale = interleave_scales(w2_weight_scale)
         layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)

         # Process input scales
...
@@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         topk_weights, topk_ids, _ = topk_output
-        local_topk_ids = torch.where(topk_ids == -1, layer.num_experts, topk_ids)
+        local_topk_ids = topk_ids
+        if get_moe_expert_parallel_world_size() > 1:
+            local_topk_ids = torch.where(
+                topk_ids == -1,
+                layer.num_experts,
+                topk_ids,
+            )

         output = cutlass_w4a8_moe(
             layer.start_expert_id,
...
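The masking now runs only under expert parallelism; in tp mode every rank holds all experts, so no topk id is ever -1. A small illustration of what the EP branch does (sizes made up):

import torch

num_experts = 8                              # local expert count, illustrative
topk_ids = torch.tensor([[0, 3], [-1, 5]])   # -1 marks a token routed to another rank
local_topk_ids = torch.where(topk_ids == -1, num_experts, topk_ids)
# tensor([[0, 3], [8, 5]]) -- the sentinel 8 is one past the last valid expert id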
python/sglang/srt/models/deepseek_v2.py

@@ -2185,6 +2185,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
+        elif self.quant_config.get_name() == "w4afp8":
+            disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
...
@@ -2496,6 +2498,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        # Params for special naming rules in mixed-precision models, for example:
+        # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
+        # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
+        if self.quant_config and self.quant_config.get_name() == "w4afp8":
+            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
+                num_experts=self.config.n_routed_experts
...
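For orientation, a hypothetical sketch of the kind of entries an input-scale mapping has to produce for the checkpoint naming above; the real list comes from FusedMoE.make_expert_input_scale_params_mapping, and the names and tuple layout below are illustrative only:

def sketch_input_scale_mapping(num_experts: int):
    # (model param name, checkpoint weight suffix, expert id, shard id) -- hypothetical
    return [
        ("experts.w13_input_scale", f"experts.{e}.{w}.input_scale", e, w)
        for e in range(num_experts)
        for w in ("w1", "w3")  # gate and up projections share one fused input scale
    ]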
python/sglang/test/test_cutlass_w4a8_moe.py

 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional
+from typing import Literal, Optional

 import pytest
 import torch

...

@@ -25,7 +25,7 @@ def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
     return packed_tensor.to(torch.int8)


-def pack_interleave(num_experts, ref_weight, ref_scale):
+def pack_interleave(num_experts, ref_weight, ref_scale, alignment=4):
     n, k = ref_weight.shape[1], ref_weight.shape[2]

     weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
...
@@ -33,11 +33,16 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
     w_q = w_q.contiguous()

     scale_interleaved = ref_scale.reshape(
-        ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
+        ref_scale.shape[0],
+        ref_scale.shape[1],
+        (ref_scale.shape[2] // alignment),
+        alignment,
     )  # [E, N, K/4, 4]
     scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
     scale_interleaved = scale_interleaved.reshape(
-        ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
+        ref_scale.shape[0],
+        ref_scale.shape[2] // alignment,
+        ref_scale.shape[1] * alignment,
     )  # [E, K/4, N*4]
     w_scale = scale_interleaved.contiguous()
...
@@ -48,12 +53,17 @@ def pack_interleave(num_experts, ref_weight, ref_scale, alignment=4):
 @pytest.mark.parametrize("N", [2048])
 @pytest.mark.parametrize("K", [7168])
 @pytest.mark.parametrize("E", [256])
-@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.parametrize("tp_size", [8])
+@pytest.mark.parametrize("use_ep_moe", [True, False])
 @pytest.mark.parametrize("topk", [8])
 @pytest.mark.parametrize("group_size", [128])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
-    local_e = E // ep_size
+def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dtype):
+    if use_ep_moe:
+        local_e = E // tp_size
+    else:  # tp mode
+        local_e = E
+        N = N // tp_size
     debug = False
     if debug:
...
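The two sharding modes the updated test exercises, spelled out with its default sizes (E=256 experts, N=2048, tp_size=8):

E, N, tp_size = 256, 2048, 8

# use_ep_moe=True: experts are split across ranks, weight shapes unchanged.
local_e_ep, n_ep = E // tp_size, N        # 32 local experts, full N=2048

# use_ep_moe=False (tp mode): every rank holds all experts, but each
# expert's N dimension is sliced tp_size ways.
local_e_tp, n_tp = E, N // tp_size        # 256 local experts, N=256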
@@ -87,7 +97,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dtype):
     )

     w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
-    w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
+    if use_ep_moe:
+        w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
+    else:
+        w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2, 1)

     device = "cuda"
     a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
...
@@ -265,7 +278,9 @@ def ref(
     gate, fc1 = fc1.chunk(2, dim=-1)
     fc1 = fc1 * torch.nn.functional.silu(gate)
-    act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
+    act = torch.clamp(
+        (fc1 / pre_quant_scale_2.float()), -448.0, 448.0
+    ).to(torch.float8_e4m3fn)
     act = act.to(dtype)

     w2 = ref_weight_2[e_idx]
...
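Why the reference now clamps before the fp8 cast: float8_e4m3fn tops out at ±448, and PyTorch's conversion maps out-of-range values to NaN rather than saturating, so unclamped activations can poison the reference output. A small demonstration:

import torch

x = torch.tensor([500.0, -500.0, 100.0])
bad = x.to(torch.float8_e4m3fn)                               # 500.0 -> nan
good = torch.clamp(x, -448.0, 448.0).to(torch.float8_e4m3fn)  # 500.0 -> 448.0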
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh

@@ -31,7 +31,7 @@ __global__ void int4_fp8_get_group_gemm_starts(
     b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;
     out_offsets[expert_id] = out_base_as_int + expert_offset * n;
     a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
-    b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id);
+    b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * k / 128 : expert_id);
   }

 #define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
...
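The rewritten b_scales offset is the same quantity in simplified form: n * 4 * k / 512 equals n * k / 128 (one scale per group of 128 quantized weights along K, previously expressed as packs of 4 scales per 512 weights). A quick check with the gemm-1 sizes from this commit:

n, k = 4096, 7168
assert n * 4 * k // 512 == n * k // 128  # 229376 scale elements per expert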
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu

@@ -2,6 +2,8 @@
 #include <cudaTypedefs.h>
 #include <torch/all.h>
+#include <type_traits>

+#include "cutlass/cutlass.h"
 #include "w4a8_grouped_mm_c3x.cuh"
...
@@ -9,38 +11,60 @@ using namespace cute;

 namespace {

-#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c
-#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
-
-#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \
-  struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
-    using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
-    using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
-    \
-    using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
-  };
-
-#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \
-  struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
-    using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
-    using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
-    \
-    using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
-  };
-
-GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1)
-GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1)
-GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1)
-GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1)
-GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1)
-GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1)
-GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1)
+enum class Sched { PP, CO };
+
+template <int M, int N, int K, int A, int B, int C, Sched S>
+struct SM90W4A8Config {
+  using KernelSchedule = std::conditional_t<
+      S == Sched::PP,
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong,
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
+  using EpilogueSchedule = std::conditional_t<
+      S == Sched::PP,
+      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong,
+      cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>;
+  using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;
+  using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;
+  using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
+};
+
+template <int M, int N, int K, int A, int B, int C>
+using SM90_PP = SM90W4A8Config<M, N, K, A, B, C, Sched::PP>;
+template <int M, int N, int K, int A, int B, int C>
+using SM90_CO = SM90W4A8Config<M, N, K, A, B, C, Sched::CO>;
+
+template <typename Config>
+inline void invoke_gemm(
+    torch::Tensor& d_tensors,
+    torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors,
+    torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales,
+    torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes,
+    torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides,
+    torch::Tensor const& d_strides,
+    torch::Tensor const& s_strides,
+    int64_t chunk_size) {
+  using GemmT = typename Config::Cutlass3xW4A8Gemm;
+  cutlass_w4a8_group_gemm_caller<GemmT>(
+      d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+      problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
+}

 void dispatch_w4a8_moe_mm_sm90(
     torch::Tensor& d_tensors,
...
@@ -56,9 +80,6 @@ void dispatch_w4a8_moe_mm_sm90(
     torch::Tensor const& s_strides,
     int64_t chunk_size,
     int64_t topk) {
-  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
-  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
-
   uint32_t const m = a_tensors.size(0) / topk;
   uint32_t const n = d_tensors.size(1);
   uint32_t const k = a_tensors.size(1);
...
@@ -66,8 +87,7 @@ void dispatch_w4a8_moe_mm_sm90(
   if (n == 4096 && k == 7168) {
     // group gemm 1
     if (m <= 4) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_PP<64, 32, 512, 2, 1, 1>>(
           d_tensors,
           a_tensors,
           b_tensors,
...
@@ -81,8 +101,7 @@ void dispatch_w4a8_moe_mm_sm90(
           s_strides,
           chunk_size);
     } else if (m <= 16) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
           d_tensors,
           a_tensors,
           b_tensors,
...
@@ -96,8 +115,7 @@ void dispatch_w4a8_moe_mm_sm90(
           s_strides,
           chunk_size);
     } else if (m <= 256) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 16, 512, 1, 1, 1>>(
           d_tensors,
           a_tensors,
           b_tensors,
...
@@ -111,8 +129,7 @@ void dispatch_w4a8_moe_mm_sm90(
           s_strides,
           chunk_size);
     } else if (m <= 1024) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 32, 512, 2, 1, 1>>(
           d_tensors,
           a_tensors,
           b_tensors,
...
@@ -126,8 +143,7 @@ void dispatch_w4a8_moe_mm_sm90(
           s_strides,
           chunk_size);
     } else {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
           d_tensors,
           a_tensors,
           b_tensors,
...
@@ -144,8 +160,125 @@ void dispatch_w4a8_moe_mm_sm90(
   } else if (n == 7168 && k == 2048) {
     // group gemm 2
     if (m <= 8) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_PP<64, 16, 512, 1, 1, 1>>(
           d_tensors,
           a_tensors,
           b_tensors,
           a_scales,
           b_scales,
           expert_offsets,
           problem_sizes,
           a_strides,
           b_strides,
           d_strides,
           s_strides,
           chunk_size);
+    } else if (m <= 512) {
+      invoke_gemm<SM90_CO<128, 32, 512, 1, 1, 1>>(
+          d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+          problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
+    } else {
+      invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
+          d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+          problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
+    }
+  } else if (n == 512 && k == 7168) {
+    // group gemm 1 for tp
+    if (m <= 4) {
+      invoke_gemm<SM90_PP<64, 32, 512, 2, 1, 1>>(
+          d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+          problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
+    } else if (m <= 16) {
+      invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
+          d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+          problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
+    } else if (m <= 256) {
+      invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
+          d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+          problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
+    } else if (m <= 1024) {
+      invoke_gemm<SM90_CO<128, 32, 512, 2, 1, 1>>(
+          d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+          problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
+    } else {
+      invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
+          d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+          problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
+    }
+  } else if (n == 7168 && k == 256) {
+    // group gemm 2 for tp
+    if (m <= 8) {
+      invoke_gemm<SM90_PP<64, 16, 128, 1, 1, 1>>(
           d_tensors,
           a_tensors,
           b_tensors,
...
@@ -159,8 +292,7 @@ void dispatch_w4a8_moe_mm_sm90(
           s_strides,
           chunk_size);
     } else if (m <= 512) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_PP<128, 32, 128, 2, 1, 1>>(
           d_tensors,
           a_tensors,
           b_tensors,
...
@@ -174,8 +306,7 @@ void dispatch_w4a8_moe_mm_sm90(
           s_strides,
           chunk_size);
     } else {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_PP<128, 64, 128, 1, 1, 1>>(
           d_tensors,
           a_tensors,
           b_tensors,
...
@@ -190,20 +321,35 @@ void dispatch_w4a8_moe_mm_sm90(
           chunk_size);
     }
   } else {
-    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
-        d_tensors,
-        a_tensors,
-        b_tensors,
-        a_scales,
-        b_scales,
-        expert_offsets,
-        problem_sizes,
-        a_strides,
-        b_strides,
-        d_strides,
-        s_strides,
-        chunk_size);
+    if (k % 512 == 0) {
+      invoke_gemm<SM90_CO<128, 32, 512, 1, 1, 1>>(
+          d_tensors,
+          a_tensors,
+          b_tensors,
+          a_scales,
+          b_scales,
+          expert_offsets,
+          problem_sizes,
+          a_strides,
+          b_strides,
+          d_strides,
+          s_strides,
+          chunk_size);
+    } else {
+      invoke_gemm<SM90_PP<128, 64, 128, 1, 1, 1>>(
+          d_tensors,
+          a_tensors,
+          b_tensors,
+          a_scales,
+          b_scales,
+          expert_offsets,
+          problem_sizes,
+          a_strides,
+          b_strides,
+          d_strides,
+          s_strides,
+          chunk_size);
+    }
   }
 }
...
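The shape of the new fallback, sketched in Python for orientation (function name is illustrative): a tile with K=512 packs 4 scales per group (group_size 128), so it is only valid when K is a multiple of 512; otherwise the dispatcher drops to a K=128 tile carrying a single packed scale.

def pick_fallback_tile(k: int):
    if k % 512 == 0:
        return ("CO", 128, 32, 512)  # cooperative schedule, 4-scale packs
    return ("PP", 128, 64, 128)      # pingpong schedule, 1-scale packs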
sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh

@@ -41,9 +41,8 @@ using MmaType = cutlass::float_e4m3_t;  // FP8 e4m3 type
 using QuantType = cutlass::int4b_t;        // 4-bit integer type
 using ElementAccumulator = float;          // Accumulator type
 using ElementScale = cutlass::bfloat16_t;  // Scale type
-using ElementScalePacked = cutlass::Array<ElementScale, 4>;
 using ElementC = cutlass::half_t;          // Default output type (FP16)
 using ElementD = ElementC;                 // Default output type (FP16)

 using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;

 // Architecture-specific configurations
...
@@ -73,6 +72,10 @@ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

 template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
 struct cutlass_3x_w4a8_group_gemm {
+  static constexpr int GroupSize = 128;
+  static constexpr int PackedScalesNum = get<2>(TileShape{}) / GroupSize;
+  using ElementScalePacked = cutlass::Array<ElementScale, PackedScalesNum>;
+
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
       OperatorClass,
...
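The packed-scale width was previously a file-level constant (Array<ElementScale, 4>); it now lives inside the gemm struct and follows the tile's K extent, which is what lets the K=128 tp tiles above compile. Worked values:

GroupSize = 128
assert 512 // GroupSize == 4  # tile K=512 -> cutlass::Array<ElementScale, 4>
assert 128 // GroupSize == 1  # tile K=128 -> cutlass::Array<ElementScale, 1>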
@@ -184,8 +187,6 @@ void cutlass_w4a8_group_gemm_caller(
   TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
   TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
   TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
-  TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512");
-  TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N");

   // Check tensor types
   TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
...
@@ -241,7 +242,7 @@ void cutlass_w4a8_group_gemm_caller(
       static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
       static_cast<const MmaType**>(a_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
-      static_cast<const ElementScalePacked**>(b_scales_ptrs.data_ptr()),
+      static_cast<const typename Gemm::ElementScalePacked**>(b_scales_ptrs.data_ptr()),
       static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
       static_cast<int>(chunk_size)},
      {fusion_args,
...
sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py

@@ -27,12 +27,18 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
     w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
     w_q = w_q.contiguous()

+    alignment = 4 if k % 512 == 0 else 1
     scale_interleaved = ref_scale.reshape(
-        ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
+        ref_scale.shape[0],
+        ref_scale.shape[1],
+        (ref_scale.shape[2] // alignment),
+        alignment,
     )  # [E, N, K/4, 4]
     scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
     scale_interleaved = scale_interleaved.reshape(
-        ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
+        ref_scale.shape[0],
+        ref_scale.shape[2] // alignment,
+        ref_scale.shape[1] * alignment,
     )  # [E, K/4, N*4]
     w_scale = scale_interleaved.contiguous()
...
@@ -137,8 +143,8 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     reason="cutlass_w4a8_moe_mm is only supported on sm90",
 )
 @pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
-@pytest.mark.parametrize("k", [512, 1024])
-@pytest.mark.parametrize("n", [1024, 2048])
+@pytest.mark.parametrize("k", [256, 512, 1024])
+@pytest.mark.parametrize("n", [1024, 2048, 7168])
 @pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
 def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     torch.manual_seed(0)
...