change / sglang · Commits · 0711d150

[NVIDIA] Fix cutedsl backend of MoE (#12353)

Unverified commit 0711d150, authored Nov 04, 2025 by Kaixi Hou; committed by GitHub on Nov 04, 2025.
Parent: 09938e1f
Showing 8 changed files with 579 additions and 28 deletions (+579 −28).
.github/workflows/pr-test.yml                              +2   −2
python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py     +9   −9
python/sglang/srt/layers/quantization/modelopt_quant.py    +6   −2
python/sglang/test/test_utils.py                           +1   −1
scripts/ci/ci_install_deepep.sh                            +58  −8
test/srt/run_suite.py                                      +3   −0
test/srt/test_cutedsl_moe.py                               +482 −0
test/srt/test_deepseek_v3_cutedsl_4gpu.py                  +18  −6
.github/workflows/pr-test.yml

@@ -821,7 +821,7 @@ jobs:
           python3 run_suite.py --suite per-commit-4-gpu-b200 --auto-partition-id 0 --auto-partition-size 1 --timeout-per-file 3600

   unit-test-backend-4-gpu-gb200:
-    needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels-arm]
+    needs: [check-changes, sgl-kernel-build-wheels-arm]
     if: always() && !failure() && !cancelled() &&
       ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))
     runs-on: 4-gpu-gb200
@@ -841,7 +841,7 @@ jobs:
       - name: Install dependencies
         run: |
-          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} IS_BLACKWELL=1 bash scripts/ci/ci_install_dependency.sh
+          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} IS_BLACKWELL=1 GRACE_BLACKWELL=1 bash scripts/ci/ci_install_deepep.sh
       - name: Run test
         timeout-minutes: 45
python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py

-from typing import Optional, Union
+from typing import Optional

 import torch
 from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
@@ -20,7 +20,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:
 def flashinfer_cutedsl_moe_masked(
-    hidden_states: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
+    hidden_states: tuple[torch.Tensor, Optional[torch.Tensor]],
     input_global_scale: torch.Tensor,
     w1: torch.Tensor,
     w1_blockscale: torch.Tensor,
@@ -40,7 +40,7 @@ def flashinfer_cutedsl_moe_masked(
     Args:
         hidden_states: Either of the following case
-            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, None]: [num_experts, m, k], bf16, None means no quant
             * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn
         input_global_scale (torch.Tensor): (l,)
         w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
@@ -74,21 +74,21 @@ def flashinfer_cutedsl_moe_masked(
     assert (
         w2_alpha.dtype == torch.float32
     ), f"w2_alpha must be float32, got {w2_alpha.dtype}"
+    assert (
+        len(hidden_states) == 2
+    ), f"hidden_states must be a tuple of length 2, got {len(hidden_states)}"

     # === Assertions on shapes ===
     n = w2.shape[-1] * 2  # intermediate dimension
-    if isinstance(hidden_states, tuple):
+    if hidden_states[1] is not None:
+        assert (
+            input_global_scale is None
+        ), "input_global_scale is needed when input needs quant"
         a_q = hidden_states[0].view(torch.uint8)
         a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
         m, k_by_2, num_experts = a_q.shape
         k = k_by_2 * 2
     else:
-        num_experts, m, k = hidden_states.shape
+        num_experts, m, k = hidden_states[0].shape
         assert (
             input_global_scale.dtype == torch.float32
@@ -98,7 +98,7 @@ def flashinfer_cutedsl_moe_masked(
         ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
         a_q, a_q_sf = scaled_fp4_grouped_quant(
-            hidden_states,
+            hidden_states[0],
             input_global_scale,
             masked_m,
         )
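Taken together, the signature and branch changes mean callers now always pass a 2-tuple. A minimal sketch of the two accepted forms (not part of the commit; the tensor shapes follow the unpacking in the quantized branch above, the actual kernel call is omitted, and a CUDA device is assumed):

import torch

num_experts, m, k = 8, 16, 256
device = "cuda"  # assumption: illustration requires a CUDA device

# bf16 path: pass (activations, None); the kernel quantizes internally
# using the per-expert input_global_scale of shape (num_experts,).
hidden_states = (
    torch.randn(num_experts, m, k, dtype=torch.bfloat16, device=device),
    None,
)

# pre-quantized path: pass packed fp4 data plus its fp8 block scales and
# set input_global_scale=None (the new assertion enforces this).
a_q = torch.empty(m, k // 2, num_experts, dtype=torch.uint8, device=device)
a_q_sf = torch.empty(m, k // 16, num_experts, dtype=torch.float8_e4m3fn, device=device)
hidden_states = (a_q, a_q_sf)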
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -1451,7 +1451,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
         layer.dispatcher.set_quant_config(
-            {"input_global_scale": layer.w13_input_scale_quant}
+            {
+                "input_global_scale": (
+                    layer.w13_input_scale_quant if CUTEDSL_MOE_NVFP4_DISPATCH else None
+                )
+            }
         )
         # Validate weight scales
@@ -1688,7 +1692,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     def apply_without_routing_weights(
         self,
         layer: FusedMoE,
-        x: torch.Tensor,
+        x: tuple[torch.Tensor, Optional[torch.Tensor]],
         masked_m: torch.Tensor,
         moe_runner_config: MoeRunnerConfig,
         down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
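The net effect of the first hunk is that the dispatcher only receives a real input_global_scale when CuteDSL nvfp4 dispatch is enabled; otherwise it sees None and activations stay bf16 until the MoE kernel quantizes them itself, matching the (tensor, None) path above. A stripped-down sketch of that gating (the function name and boolean argument are stand-ins; only the conditional mirrors the diff):

from typing import Optional

import torch


def build_quant_config(
    w13_input_scale_quant: torch.Tensor, nvfp4_dispatch: bool
) -> dict[str, Optional[torch.Tensor]]:
    # With nvfp4 dispatch disabled, downstream receives None and the kernel
    # quantizes activations internally.
    return {
        "input_global_scale": w13_input_scale_quant if nvfp4_dispatch else None
    }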
python/sglang/test/test_utils.py

@@ -57,7 +57,7 @@ DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
 DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN"

 # NVFP4 models
-DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-R1-0528-FP4"
+DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-V3-0324-FP4"

 # FP8 models
 DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
scripts/ci/ci_install_deepep.sh

@@ -10,9 +10,20 @@ export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
 export PATH="${NVSHMEM_DIR}/bin:$PATH"
 export CUDA_HOME=/usr/local/cuda
+GRACE_BLACKWELL=${GRACE_BLACKWELL:-0}

-if python3 -c "import deep_ep" > /dev/null 2>&1; then
-  echo "deep_ep is already installed or importable. Skipping installation."
-  exit 0
+# Detect architecture
+ARCH=$(uname -m)
+if [ "$ARCH" != "x86_64" ] && [ "$ARCH" != "aarch64" ]; then
+  echo "Unsupported architecture: $ARCH"
+  exit 1
+fi
+
+# It seems GB200 ci runner preinstalls some wrong version of deep_ep, so we cannot rely on it.
+if [ "$GRACE_BLACKWELL" != "1" ]; then
+  if python3 -c "import deep_ep" > /dev/null 2>&1; then
+    echo "deep_ep is already installed or importable. Skipping installation."
+    exit 0
+  fi
 fi

 # Install system dependencies
@@ -35,8 +46,10 @@ dpkg -i libgdrapi_*.deb
 dpkg -i gdrcopy-tests_*.deb
 dpkg -i gdrcopy_*.deb

-if [ ! -e "/usr/lib/x86_64-linux-gnu/libmlx5.so" ]; then
-  ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so
+# Set up library paths based on architecture
+LIB_PATH="/usr/lib/$ARCH-linux-gnu"
+if [ ! -e "$LIB_PATH/libmlx5.so" ]; then
+  ln -s $LIB_PATH/libmlx5.so.1 $LIB_PATH/libmlx5.so
 fi
 apt-get update && apt-get install -y libfabric-dev
@@ -45,6 +58,11 @@ cd /opt/nvshmem
 wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.4.5/source/nvshmem_src_cuda12-all-all-3.4.5.tar.gz
 tar -xf nvshmem_src_cuda12-all-all-3.4.5.tar.gz
 mv nvshmem_src nvshmem && cd nvshmem
+if [ "$GRACE_BLACKWELL" = "1" ]; then
+  CUDA_ARCH="100;120"
+else
+  CUDA_ARCH="90"
+fi
 NVSHMEM_SHMEM_SUPPORT=0 \
 NVSHMEM_UCX_SUPPORT=0 \
 NVSHMEM_USE_NCCL=0 \
@@ -53,13 +71,45 @@ NVSHMEM_IBGDA_SUPPORT=1 \
 NVSHMEM_PMIX_SUPPORT=0 \
 NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
 NVSHMEM_USE_GDRCOPY=1 \
-cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90
+cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH}
 cd build
 make -j$(nproc) install

 # Install DeepEP
-rm -rf /root/.cache/deepep && git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep && cd /root/.cache/deepep && git checkout 9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee
-cd /root/.cache/deepep && python3 setup.py install
+DEEPEP_DIR=/root/.cache/deepep
+rm -rf ${DEEPEP_DIR}
+if [ "$GRACE_BLACKWELL" = "1" ]; then
+  # We use Tom's DeepEP fork for GB200 for now, which supports fp4 dispatch.
+  GRACE_BLACKWELL_DEEPEP_BRANCH=gb200_blog_part_2
+  git clone https://github.com/fzyzcjy/DeepEP.git ${DEEPEP_DIR} && \
+    pushd ${DEEPEP_DIR} && \
+    git checkout ${GRACE_BLACKWELL_DEEPEP_BRANCH} && \
+    sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh && \
+    popd
+else
+  git clone https://github.com/deepseek-ai/DeepEP.git ${DEEPEP_DIR} && \
+    pushd ${DEEPEP_DIR} && \
+    git checkout 9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee && \
+    popd
+fi
+
+cd ${DEEPEP_DIR}
+if [ "$GRACE_BLACKWELL" = "1" ]; then
+  CUDA_VERSION=$(nvidia-smi | grep "CUDA Version" | head -n1 | awk '{print $9}')
+  if [ "$CUDA_VERSION" = "12.8" ]; then
+    CHOSEN_TORCH_CUDA_ARCH_LIST='10.0'
+  elif awk -v ver="$CUDA_VERSION" 'BEGIN {exit !(ver > 12.8)}'; then
+    CHOSEN_TORCH_CUDA_ARCH_LIST='10.0;10.3'
+  else
+    echo "Unsupported CUDA version for Grace Blackwell: $CUDA_VERSION" && exit 1
+  fi && \
+  if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
+    sed -i "/^ include_dirs = \['csrc\/'\]/a\ include_dirs.append('${CUDA_HOME}/include/cccl')" setup.py; \
+  fi
+  NVSHMEM_DIR=/opt/nvshmem/install TORCH_CUDA_ARCH_LIST="${CHOSEN_TORCH_CUDA_ARCH_LIST}" pip install --no-build-isolation .
+else
+  python3 setup.py install
+fi

 # Verify configuration
 echo "=== Verify NVSHMEM ==="
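For the Grace Blackwell branch, TORCH_CUDA_ARCH_LIST is derived from the driver-reported CUDA version. A reference-only Python rendering of that selection logic (the function name is hypothetical; the script itself does this in shell with awk):

def choose_torch_cuda_arch_list(cuda_version: str) -> str:
    # Mirrors the shell logic above: exact 12.8 builds sm_100 only,
    # anything newer also targets sm_103, older versions are rejected.
    if cuda_version == "12.8":
        return "10.0"
    if float(cuda_version) > 12.8:
        return "10.0;10.3"
    raise ValueError(f"Unsupported CUDA version for Grace Blackwell: {cuda_version}")


assert choose_torch_cuda_arch_list("12.8") == "10.0"
assert choose_torch_cuda_arch_list("13.0") == "10.0;10.3"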
test/srt/run_suite.py

@@ -179,7 +179,10 @@ suites = {
         TestFile("test_llama31_fp4.py", 300),
     ],
     "per-commit-4-gpu-gb200": [
+        TestFile("test_cutedsl_moe.py", 300),
         TestFile("test_deepseek_v3_fp4_4gpu.py", 3600),
+        # Disabled temporarily, see https://github.com/sgl-project/sglang/issues/12533
+        # TestFile("test_deepseek_v3_cutedsl_4gpu.py", 3600),
     ],
     "per-commit-4-gpu-deepep": [
         TestFile("ep/test_deepep_small.py", 531),
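The added entry registers the new unit test in the gb200 suite with a 300 s per-file budget. By analogy with the b200 job in the workflow hunk above, the gb200 runner presumably drives it the same way, e.g. `python3 run_suite.py --suite per-commit-4-gpu-gb200 --timeout-per-file 3600`.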
test/srt/test_cutedsl_moe.py (new file, 0 → 100644)

# SPDX-License-Identifier: Apache-2.0
import unittest
from typing import Callable

import torch
from flashinfer import fp4_quantize
from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
from torch.nn import functional as F

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked
from sglang.srt.layers.moe.topk import TopKConfig, select_experts

SKIP_TEST = torch.cuda.get_device_capability() < (10, 0)
SKIP_REASON = "Nvfp4 Requires compute capability of 10 or above."

kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)

FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0


def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    return out[0:m, 0:k]


def dequantize_nvfp4_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # scale the tensor
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out.to(dtype=dtype)


def break_fp4_bytes(a, dtype):
    assert a.dtype == torch.uint8
    m, n = a.shape

    # Vectorized nibble processing
    a_flat = a.flatten()
    high = (a_flat & 0xF0) >> 4  # Upper nibbles
    low = a_flat & 0x0F  # Lower nibbles

    # Combine nibbles for batch processing
    combined = torch.stack((low, high), dim=1).flatten()

    # Vectorized sign and magnitude extraction
    signs = (combined & 0x08).to(torch.bool)  # Sign bits
    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices

    # Device-aware lookup and sign application
    kE2M1 = kE2M1ToFloat.to(device=a.device)
    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

    # Reshape to final form
    return values.reshape(m, n * 2).to(dtype=dtype)


def compute_routing(router_logits: torch.Tensor, top_k: int):
    routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.float()
    return routing_weights, selected_experts


def prepare_inputs(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    num_experts: int,
    topk: int,
):
    routing_weights, topk_idx = compute_routing(router_logits, topk)

    masked_m = []
    for i in range(num_experts):
        mask = topk_idx.view(-1) == i
        masked_m.append(mask.sum())

    masked_m = torch.tensor(masked_m, dtype=torch.int32)
    hidden_states_3d = torch.empty(
        (num_experts, max(masked_m), hidden_states.shape[1]),
        dtype=hidden_states.dtype,
    )
    for i in range(num_experts):
        hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]

    return hidden_states_3d, masked_m, topk_idx, routing_weights


MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1024),
    (224, 1024, 1024),
    (224, 1024, 1536),
]


# Reference implementation of torch_moe
def torch_moe(a, w1, w2, score, topk, expert_map):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[
                i
            ].transpose(0, 1)
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            m = w1[i].shape[0]
            assert m % 2 == 0
            # Note: w1 and w3 are swapped!
            w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
            inter_gs = torch.tensor(1.0).cuda()
            inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
            inter = dequantize_nvfp4_to_dtype(
                inter_q,
                inter_blockscale,
                inter_gs,
                dtype=inter.dtype,
                device=inter.device,
                block_size=16,
            ).cuda()
            out[mask] = inter @ w2[i].transpose(0, 1)
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


def check_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    moe_impl: Callable,
    flip_w13: bool,
):
    torch.manual_seed(7)
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10

    quant_blocksize = 16
    round_up = lambda x, y: (x + y - 1) // y * y
    sf_w1_2n = round_up(2 * n, 128)
    sf_w1_k = round_up(k // quant_blocksize, 4)
    w1_blockscale = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )

    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    sf_w2_k = round_up(k, 128)
    sf_w2_n = round_up(n // quant_blocksize, 4)
    w2_blockscale = torch.empty(
        (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
    )

    w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
    w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)

    for expert in range(e):
        w1_amax = torch.abs(w1).max().to(torch.float32)
        w2_amax = torch.abs(w2).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

        w1_q[expert], w1_blockscale[expert] = scaled_fp4_quant(
            w1[expert], w1_gs[expert]
        )
        w2_q[expert], w2_blockscale[expert] = scaled_fp4_quant(
            w2[expert], w2_gs[expert]
        )

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_output = select_experts(
        hidden_states=a,
        router_logits=score,
        topk_config=TopKConfig(top_k=topk, renormalize=False),
    )
    topk_weights, topk_ids, _ = topk_output

    a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
    a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)

    test_output = moe_impl(
        a=a,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        w1_q=w1_q,
        w2_q=w2_q,
        a1_gs=a1_gs,
        w1_blockscale=w1_blockscale,
        w1_alphas=(1 / w1_gs),
        a2_gs=a2_gs,
        w2_blockscale=w2_blockscale,
        w2_alphas=(1 / w2_gs),
    )

    # Reference check:
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
    ).to(torch.float32)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale)
    _, m_k = a_fp4.shape
    a_in_dtype = dequantize_nvfp4_to_dtype(
        a_fp4,
        a_scale_interleaved,
        a_global_scale,
        dtype=a.dtype,
        device=a.device,
        block_size=quant_blocksize,
    )

    w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
    w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)

    for idx in range(0, e):
        w1_d[idx] = dequantize_nvfp4_to_dtype(
            w1_q[idx],
            w1_blockscale[idx],
            w1_gs[idx],
            dtype=w1.dtype,
            device=w1.device,
            block_size=quant_blocksize,
        )
        w2_d[idx] = dequantize_nvfp4_to_dtype(
            w2_q[idx],
            w2_blockscale[idx],
            w2_gs[idx],
            dtype=w2.dtype,
            device=w2.device,
            block_size=quant_blocksize,
        )

    if flip_w13:
        dim = -2
        size = w1_d.size(dim)
        assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
        half = size // 2
        # Reorder weight
        w1, w3 = w1_d.split(half, dim=dim)
        w1_d = torch.cat([w3, w1], dim=dim).contiguous()

    torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
    torch.testing.assert_close(torch_output, test_output, atol=1e-1, rtol=1e-1)


class TestFlashinferCutedslMoe(unittest.TestCase):
    @unittest.skipIf(SKIP_TEST, SKIP_REASON)
    def test_flashinfer_cutedsl_moe_masked(self):
        # Test parameters
        test_cases = [
            (2, 128, 256, 1),
            (2, 128, 256, 2),
            (2, 128, 256, 4),
            (16, 128, 512, 1),
            (16, 128, 512, 2),
            (16, 128, 512, 4),
        ]

        for bs, hidden_dim, inter_dim, topk in test_cases:
            with self.subTest(
                bs=bs, hidden_dim=hidden_dim, inter_dim=inter_dim, topk=topk
            ):
                print(
                    f"Testing with bs={bs}, hidden_dim={hidden_dim}, inter_dim={inter_dim}, topk={topk}"
                )
                with torch.inference_mode():
                    torch.manual_seed(42)
                    device = "cuda"
                    dtype = torch.bfloat16
                    num_experts = 8

                    hidden_states = (
                        torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device)
                        / 5.0
                    )
                    w1 = (
                        torch.randn(
                            num_experts,
                            2 * inter_dim,
                            hidden_dim,
                            dtype=torch.bfloat16,
                            device=device,
                        )
                        / 10.0
                    )
                    w2 = (
                        torch.randn(
                            num_experts,
                            hidden_dim,
                            inter_dim,
                            dtype=torch.bfloat16,
                            device=device,
                        )
                        / 10.0
                    )
                    router_logits = torch.randn(bs, num_experts, dtype=torch.float32)

                    hidden_states_expanded = (
                        hidden_states.view(bs, -1, hidden_dim)
                        .repeat(1, topk, 1)
                        .reshape(-1, hidden_dim)
                    )
                    hidden_states_3d, masked_m, topk_idx, routing_weights = (
                        prepare_inputs(
                            hidden_states_expanded, router_logits, num_experts, topk
                        )
                    )

                    w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
                    w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
                    input_global_scale = torch.ones(
                        (num_experts,), dtype=torch.float32, device=hidden_states.device
                    )
                    w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
                    w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
                    a2_global_scale = torch.ones(
                        (num_experts,), dtype=torch.float32, device=hidden_states.device
                    )  # assume intermediate scale is 1.0

                    w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
                        w1,
                        w1_global_scale,
                        torch.ones(num_experts, dtype=torch.int32, device=w1.device)
                        * 2
                        * inter_dim,
                    )
                    w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
                        w2,
                        w2_global_scale,
                        torch.ones(num_experts, dtype=torch.int32, device=w2.device)
                        * hidden_dim,
                    )

                    w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
                    w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)

                    out = flashinfer_cutedsl_moe_masked(
                        (hidden_states_3d.to(hidden_states.device), None),
                        input_global_scale,
                        w1_fp4.permute(2, 0, 1),
                        w1_blockscale,
                        w1_alpha,
                        w2_fp4.permute(2, 0, 1),
                        a2_global_scale,
                        w2_blockscale,
                        w2_alpha,
                        masked_m.to(hidden_states.device),
                    )

                    # reference
                    a_fp4, a_scale_interleaved = fp4_quantize(
                        hidden_states, input_global_scale
                    )
                    a_in_dtype = dequantize_nvfp4_to_dtype(
                        a_fp4,
                        a_scale_interleaved,
                        input_global_scale,
                        dtype=hidden_states.dtype,
                        device=hidden_states.device,
                        block_size=16,
                    )
                    w1_d = torch.empty(
                        (num_experts, 2 * inter_dim, hidden_dim),
                        device=w1.device,
                        dtype=w1.dtype,
                    )
                    w2_d = torch.empty(
                        (num_experts, hidden_dim, inter_dim),
                        device=w2.device,
                        dtype=w2.dtype,
                    )
                    for idx in range(0, num_experts):
                        w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
                            w1[idx], w1_global_scale[idx]
                        )
                        w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
                            w2[idx], w2_global_scale[idx]
                        )
                        w1_d[idx] = dequantize_nvfp4_to_dtype(
                            w1_fp4_sliced,
                            w1_blockscale_sliced,
                            w1_global_scale[idx],
                            dtype=w1.dtype,
                            device=w1.device,
                            block_size=16,
                        )
                        w2_d[idx] = dequantize_nvfp4_to_dtype(
                            w2_fp4_sliced,
                            w2_blockscale_sliced,
                            w2_global_scale[idx],
                            dtype=w2.dtype,
                            device=w2.device,
                            block_size=16,
                        )

                    ref_output = torch_moe_nvfp4(
                        a_in_dtype,
                        w1_d,
                        w2_d,
                        topk,
                        routing_weights.to(a_in_dtype.device),
                        topk_idx.to(a_in_dtype.device),
                    )

                    out_weighted = torch.zeros_like(
                        ref_output, device=out.device, dtype=out.dtype
                    )
                    positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
                    rows, cols = positions[:, 0], positions[:, 1]
                    experts = topk_idx[rows, cols]
                    for i in range(num_experts):
                        mask = experts == i
                        if mask.any():
                            idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
                            r, c = rows[idx], cols[idx]
                            out_weighted[r] += out[i, : len(r), :] * routing_weights[
                                r, c
                            ].to(out.device).unsqueeze(-1)

                    torch.testing.assert_close(
                        out_weighted.cpu(), ref_output.cpu(), atol=5e-2, rtol=5e-2
                    )
                    print(
                        f"Test passed with bs={bs}, hidden_dim={hidden_dim}, inter_dim={inter_dim}, topk={topk}"
                    )


if __name__ == "__main__":
    unittest.main()
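As a side note for readers studying the new helpers, here is a small dummy-data shape check (not part of the commit) against convert_swizzled_to_linear from the file above; it confirms that the swizzled scale factors unpack to one linear scale per 16-element quantization block. The test file itself runs directly via python3 test/srt/test_cutedsl_moe.py.

import torch

m, k, block_size = 128, 1024, 16
m_tiles = (m + 127) // 128
k_tiles = (k + block_size * 4 - 1) // (block_size * 4)

# The swizzled buffer holds m_tiles * k_tiles tiles of 32 * 4 * 4 scales each.
swizzled = torch.arange(m_tiles * k_tiles * 32 * 4 * 4, dtype=torch.float32)

linear = convert_swizzled_to_linear(swizzled, m, k, block_size)
assert linear.shape == (m, k // block_size)  # one scale per 16-element block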
test/srt/test_cutedsl_flashinfer_8gpu.py → test/srt/test_deepseek_v3_cutedsl_4gpu.py

@@ -24,20 +24,31 @@ class TestDeepseekR1Nvfp4CuteDSLDeepEP(CustomTestCase):
         other_args = [
             "--trust-remote-code",
             "--disable-radix-cache",
+            "--mem-fraction-static",
+            "0.89",
+            "--max-prefill-tokens",
+            "16384",
             "--max-running-requests",
             "256",
             "--chunked-prefill-size",
-            "2048",
+            "1024",
             "--tp",
-            "8",
+            "4",
             "--dp",
-            "8",
+            "4",
+            "--ep",
+            "4",
+            "--moe-dense-tp-size",
+            "1",
             "--enable-dp-attention",
-            "--enable-ep-moe",
             "--quantization",
             "modelopt_fp4",
-            "--enable-flashinfer-cutedsl-moe",
-            "--enable-deepep-moe",
+            "--attention-backend",
+            "trtllm_mla",
+            "--moe-a2a-backend",
+            "deepep",
+            "--moe-runner-backend",
+            "flashinfer_cutedsl",
             "--deepep-mode",
             "low_latency",
         ]
@@ -50,6 +61,7 @@ class TestDeepseekR1Nvfp4CuteDSLDeepEP(CustomTestCase):
                 **os.environ,
                 "SGLANG_DEEPEP_BF16_DISPATCH": "1",
                 "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256",
+                "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH": "0",
             },
         )