Unverified Commit 0711d150 authored by Kaixi Hou's avatar Kaixi Hou Committed by GitHub
Browse files

[NVIDIA] Fix cutedsl backend of MoE (#12353)

parent 09938e1f
...@@ -821,7 +821,7 @@ jobs: ...@@ -821,7 +821,7 @@ jobs:
python3 run_suite.py --suite per-commit-4-gpu-b200 --auto-partition-id 0 --auto-partition-size 1 --timeout-per-file 3600 python3 run_suite.py --suite per-commit-4-gpu-b200 --auto-partition-id 0 --auto-partition-size 1 --timeout-per-file 3600
unit-test-backend-4-gpu-gb200: unit-test-backend-4-gpu-gb200:
needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels-arm] needs: [check-changes, sgl-kernel-build-wheels-arm]
if: always() && !failure() && !cancelled() && if: always() && !failure() && !cancelled() &&
((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true')) ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))
runs-on: 4-gpu-gb200 runs-on: 4-gpu-gb200
...@@ -841,7 +841,7 @@ jobs: ...@@ -841,7 +841,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} IS_BLACKWELL=1 bash scripts/ci/ci_install_dependency.sh CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} IS_BLACKWELL=1 GRACE_BLACKWELL=1 bash scripts/ci/ci_install_deepep.sh
- name: Run test - name: Run test
timeout-minutes: 45 timeout-minutes: 45
......
from typing import Optional, Union from typing import Optional
import torch import torch
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
...@@ -20,7 +20,7 @@ def get_cute_dtype(input: torch.Tensor) -> str: ...@@ -20,7 +20,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:
def flashinfer_cutedsl_moe_masked( def flashinfer_cutedsl_moe_masked(
hidden_states: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], hidden_states: tuple[torch.Tensor, Optional[torch.Tensor]],
input_global_scale: torch.Tensor, input_global_scale: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w1_blockscale: torch.Tensor, w1_blockscale: torch.Tensor,
...@@ -40,7 +40,7 @@ def flashinfer_cutedsl_moe_masked( ...@@ -40,7 +40,7 @@ def flashinfer_cutedsl_moe_masked(
Args: Args:
hidden_states: Either of the following case hidden_states: Either of the following case
* torch.Tensor: [num_experts, m, k], bf16 * tuple[torch.Tensor, None]: [num_experts, m, k], bf16, None means no quant
* tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn
input_global_scale (torch.Tensor): (l,) input_global_scale (torch.Tensor): (l,)
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8 w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
...@@ -74,21 +74,21 @@ def flashinfer_cutedsl_moe_masked( ...@@ -74,21 +74,21 @@ def flashinfer_cutedsl_moe_masked(
assert ( assert (
w2_alpha.dtype == torch.float32 w2_alpha.dtype == torch.float32
), f"w2_alpha must be float32, got {w2_alpha.dtype}" ), f"w2_alpha must be float32, got {w2_alpha.dtype}"
assert (
len(hidden_states) == 2
), f"hidden_states must be a tuple of length 2, got {len(hidden_states)}"
# === Assertions on shapes === # === Assertions on shapes ===
n = w2.shape[-1] * 2 # intermediate dimension n = w2.shape[-1] * 2 # intermediate dimension
if isinstance(hidden_states, tuple): if hidden_states[1] is not None:
assert (
input_global_scale is None
), "input_global_scale is needed when input needs quant"
a_q = hidden_states[0].view(torch.uint8) a_q = hidden_states[0].view(torch.uint8)
a_q_sf = hidden_states[1].view(torch.float8_e4m3fn) a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
m, k_by_2, num_experts = a_q.shape m, k_by_2, num_experts = a_q.shape
k = k_by_2 * 2 k = k_by_2 * 2
else: else:
num_experts, m, k = hidden_states.shape num_experts, m, k = hidden_states[0].shape
assert ( assert (
input_global_scale.dtype == torch.float32 input_global_scale.dtype == torch.float32
...@@ -98,7 +98,7 @@ def flashinfer_cutedsl_moe_masked( ...@@ -98,7 +98,7 @@ def flashinfer_cutedsl_moe_masked(
), f"input_global_scale must be (l,), got {input_global_scale.shape}" ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
a_q, a_q_sf = scaled_fp4_grouped_quant( a_q, a_q_sf = scaled_fp4_grouped_quant(
hidden_states, hidden_states[0],
input_global_scale, input_global_scale,
masked_m, masked_m,
) )
......
...@@ -1451,7 +1451,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -1451,7 +1451,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
) )
layer.dispatcher.set_quant_config( layer.dispatcher.set_quant_config(
{"input_global_scale": layer.w13_input_scale_quant} {
"input_global_scale": (
layer.w13_input_scale_quant if CUTEDSL_MOE_NVFP4_DISPATCH else None
)
}
) )
# Validate weight scales # Validate weight scales
...@@ -1688,7 +1692,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -1688,7 +1692,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
def apply_without_routing_weights( def apply_without_routing_weights(
self, self,
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: tuple[torch.Tensor, Optional[torch.Tensor]],
masked_m: torch.Tensor, masked_m: torch.Tensor,
moe_runner_config: MoeRunnerConfig, moe_runner_config: MoeRunnerConfig,
down_gemm_overlap_args: Optional["DownGemmOverlapArgs"], down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
......
...@@ -57,7 +57,7 @@ DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test" ...@@ -57,7 +57,7 @@ DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN" DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN"
# NVFP4 models # NVFP4 models
DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-R1-0528-FP4" DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-V3-0324-FP4"
# FP8 models # FP8 models
DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8" DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
......
...@@ -10,9 +10,20 @@ export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH" ...@@ -10,9 +10,20 @@ export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
export PATH="${NVSHMEM_DIR}/bin:$PATH" export PATH="${NVSHMEM_DIR}/bin:$PATH"
export CUDA_HOME=/usr/local/cuda export CUDA_HOME=/usr/local/cuda
if python3 -c "import deep_ep" >/dev/null 2>&1; then GRACE_BLACKWELL=${GRACE_BLACKWELL:-0}
echo "deep_ep is already installed or importable. Skipping installation." # Detect architecture
exit 0 ARCH=$(uname -m)
if [ "$ARCH" != "x86_64" ] && [ "$ARCH" != "aarch64" ]; then
echo "Unsupported architecture: $ARCH"
exit 1
fi
# It seems GB200 ci runner preinstalls some wrong version of deep_ep, so we cannot rely on it.
if [ "$GRACE_BLACKWELL" != "1" ]; then
if python3 -c "import deep_ep" >/dev/null 2>&1; then
echo "deep_ep is already installed or importable. Skipping installation."
exit 0
fi
fi fi
# Install system dependencies # Install system dependencies
...@@ -35,8 +46,10 @@ dpkg -i libgdrapi_*.deb ...@@ -35,8 +46,10 @@ dpkg -i libgdrapi_*.deb
dpkg -i gdrcopy-tests_*.deb dpkg -i gdrcopy-tests_*.deb
dpkg -i gdrcopy_*.deb dpkg -i gdrcopy_*.deb
if [ ! -e "/usr/lib/x86_64-linux-gnu/libmlx5.so" ]; then # Set up library paths based on architecture
ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so LIB_PATH="/usr/lib/$ARCH-linux-gnu"
if [ ! -e "$LIB_PATH/libmlx5.so" ]; then
ln -s $LIB_PATH/libmlx5.so.1 $LIB_PATH/libmlx5.so
fi fi
apt-get update && apt-get install -y libfabric-dev apt-get update && apt-get install -y libfabric-dev
...@@ -45,6 +58,11 @@ cd /opt/nvshmem ...@@ -45,6 +58,11 @@ cd /opt/nvshmem
wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.4.5/source/nvshmem_src_cuda12-all-all-3.4.5.tar.gz wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.4.5/source/nvshmem_src_cuda12-all-all-3.4.5.tar.gz
tar -xf nvshmem_src_cuda12-all-all-3.4.5.tar.gz tar -xf nvshmem_src_cuda12-all-all-3.4.5.tar.gz
mv nvshmem_src nvshmem && cd nvshmem mv nvshmem_src nvshmem && cd nvshmem
if [ "$GRACE_BLACKWELL" = "1" ]; then
CUDA_ARCH="100;120"
else
CUDA_ARCH="90"
fi
NVSHMEM_SHMEM_SUPPORT=0 \ NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \ NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \ NVSHMEM_USE_NCCL=0 \
...@@ -53,13 +71,45 @@ NVSHMEM_IBGDA_SUPPORT=1 \ ...@@ -53,13 +71,45 @@ NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \ NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \ NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90 cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH}
cd build cd build
make -j$(nproc) install make -j$(nproc) install
# Install DeepEP # Install DeepEP
rm -rf /root/.cache/deepep && git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep && cd /root/.cache/deepep && git checkout 9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee DEEPEP_DIR=/root/.cache/deepep
cd /root/.cache/deepep && python3 setup.py install rm -rf ${DEEPEP_DIR}
if [ "$GRACE_BLACKWELL" = "1" ]; then
# We use Tom's DeepEP fork for GB200 for now, which supports fp4 dispatch.
GRACE_BLACKWELL_DEEPEP_BRANCH=gb200_blog_part_2
git clone https://github.com/fzyzcjy/DeepEP.git ${DEEPEP_DIR} && \
pushd ${DEEPEP_DIR} && \
git checkout ${GRACE_BLACKWELL_DEEPEP_BRANCH} && \
sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh && \
popd
else
git clone https://github.com/deepseek-ai/DeepEP.git ${DEEPEP_DIR} && \
pushd ${DEEPEP_DIR} && \
git checkout 9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee && \
popd
fi
cd ${DEEPEP_DIR}
if [ "$GRACE_BLACKWELL" = "1" ]; then
CUDA_VERSION=$(nvidia-smi | grep "CUDA Version" | head -n1 | awk '{print $9}')
if [ "$CUDA_VERSION" = "12.8" ]; then
CHOSEN_TORCH_CUDA_ARCH_LIST='10.0'
elif awk -v ver="$CUDA_VERSION" 'BEGIN {exit !(ver > 12.8)}'; then
CHOSEN_TORCH_CUDA_ARCH_LIST='10.0;10.3'
else
echo "Unsupported CUDA version for Grace Blackwell: $CUDA_VERSION" && exit 1
fi && \
if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
sed -i "/^ include_dirs = \['csrc\/'\]/a\ include_dirs.append('${CUDA_HOME}/include/cccl')" setup.py; \
fi
NVSHMEM_DIR=/opt/nvshmem/install TORCH_CUDA_ARCH_LIST="${CHOSEN_TORCH_CUDA_ARCH_LIST}" pip install --no-build-isolation .
else
python3 setup.py install
fi
# Verify configuration # Verify configuration
echo "=== Verify NVSHMEM ===" echo "=== Verify NVSHMEM ==="
......
...@@ -179,7 +179,10 @@ suites = { ...@@ -179,7 +179,10 @@ suites = {
TestFile("test_llama31_fp4.py", 300), TestFile("test_llama31_fp4.py", 300),
], ],
"per-commit-4-gpu-gb200": [ "per-commit-4-gpu-gb200": [
TestFile("test_cutedsl_moe.py", 300),
TestFile("test_deepseek_v3_fp4_4gpu.py", 3600), TestFile("test_deepseek_v3_fp4_4gpu.py", 3600),
# Disabled temporarily, see https://github.com/sgl-project/sglang/issues/12533
# TestFile("test_deepseek_v3_cutedsl_4gpu.py", 3600),
], ],
"per-commit-4-gpu-deepep": [ "per-commit-4-gpu-deepep": [
TestFile("ep/test_deepep_small.py", 531), TestFile("ep/test_deepep_small.py", 531),
......
# SPDX-License-Identifier: Apache-2.0
import unittest
from typing import Callable
import torch
from flashinfer import fp4_quantize
from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
from torch.nn import functional as F
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
SKIP_TEST = torch.cuda.get_device_capability() < (10, 0)
SKIP_REASON = "Nvfp4 Requires compute capability of 10 or above."
kE2M1ToFloat = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)
FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
def dequantize_nvfp4_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out.to(dtype=dtype)
def break_fp4_bytes(a, dtype):
assert a.dtype == torch.uint8
m, n = a.shape
# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles
# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()
# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
# Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
# Reshape to final form
return values.reshape(m, n * 2).to(dtype=dtype)
def compute_routing(router_logits: torch.Tensor, top_k: int):
routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.float()
return routing_weights, selected_experts
def prepare_inputs(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
num_experts: int,
topk: int,
):
routing_weights, topk_idx = compute_routing(router_logits, topk)
masked_m = []
for i in range(num_experts):
mask = topk_idx.view(-1) == i
masked_m.append(mask.sum())
masked_m = torch.tensor(masked_m, dtype=torch.int32)
hidden_states_3d = torch.empty(
(num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
)
for i in range(num_experts):
hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]
return hidden_states_3d, masked_m, topk_idx, routing_weights
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1024),
(224, 1024, 1024),
(224, 1024, 1536),
]
# Reference implementation of torch_moe
def torch_moe(a, w1, w2, score, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
0, 1
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
m = w1[i].shape[0]
assert m % 2 == 0
# Note: w1 and w3 are swapped!
w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
inter_gs = torch.tensor(1.0).cuda()
inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
inter = dequantize_nvfp4_to_dtype(
inter_q,
inter_blockscale,
inter_gs,
dtype=inter.dtype,
device=inter.device,
block_size=16,
).cuda()
out[mask] = inter @ w2[i].transpose(0, 1)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def check_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
moe_impl: Callable,
flip_w13: bool,
):
torch.manual_seed(7)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
round_up = lambda x, y: (x + y - 1) // y * y
sf_w1_2n = round_up(2 * n, 128)
sf_w1_k = round_up(k // quant_blocksize, 4)
w1_blockscale = torch.empty(
(e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
)
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
sf_w2_k = round_up(k, 128)
sf_w2_n = round_up(n // quant_blocksize, 4)
w2_blockscale = torch.empty(
(e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
)
w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_amax = torch.abs(w1).max().to(torch.float32)
w2_amax = torch.abs(w2).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_q[expert], w1_blockscale[expert] = scaled_fp4_quant(
w1[expert], w1_gs[expert]
)
w2_q[expert], w2_blockscale[expert] = scaled_fp4_quant(
w2[expert], w2_gs[expert]
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
topk_config=TopKConfig(top_k=topk, renormalize=False),
)
topk_weights, topk_ids, _ = topk_output
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
test_output = moe_impl(
a=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1_q=w1_q,
w2_q=w2_q,
a1_gs=a1_gs,
w1_blockscale=w1_blockscale,
w1_alphas=(1 / w1_gs),
a2_gs=a2_gs,
w2_blockscale=w2_blockscale,
w2_alphas=(1 / w2_gs),
)
# Reference check:
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
).to(torch.float32)
a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize,
)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=w1.dtype,
device=w1.device,
block_size=quant_blocksize,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=w2.dtype,
device=w2.device,
block_size=quant_blocksize,
)
if flip_w13:
dim = -2
size = w1_d.size(dim)
assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
half = size // 2
# Reorder weight
w1, w3 = w1_d.split(half, dim=dim)
w1_d = torch.cat([w3, w1], dim=dim).contiguous()
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
torch.testing.assert_close(torch_output, test_output, atol=1e-1, rtol=1e-1)
class TestFlashinferCutedslMoe(unittest.TestCase):
@unittest.skipIf(SKIP_TEST, SKIP_REASON)
def test_flashinfer_cutedsl_moe_masked(self):
# Test parameters
test_cases = [
(2, 128, 256, 1),
(2, 128, 256, 2),
(2, 128, 256, 4),
(16, 128, 512, 1),
(16, 128, 512, 2),
(16, 128, 512, 4),
]
for bs, hidden_dim, inter_dim, topk in test_cases:
with self.subTest(
bs=bs, hidden_dim=hidden_dim, inter_dim=inter_dim, topk=topk
):
print(
f"Testing with bs={bs}, hidden_dim={hidden_dim}, inter_dim={inter_dim}, topk={topk}"
)
with torch.inference_mode():
torch.manual_seed(42)
device = "cuda"
dtype = torch.bfloat16
num_experts = 8
hidden_states = (
torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device)
/ 5.0
)
w1 = (
torch.randn(
num_experts,
2 * inter_dim,
hidden_dim,
dtype=torch.bfloat16,
device=device,
)
/ 10.0
)
w2 = (
torch.randn(
num_experts,
hidden_dim,
inter_dim,
dtype=torch.bfloat16,
device=device,
)
/ 10.0
)
router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
hidden_states_expanded = (
hidden_states.view(bs, -1, hidden_dim)
.repeat(1, topk, 1)
.reshape(-1, hidden_dim)
)
hidden_states_3d, masked_m, topk_idx, routing_weights = (
prepare_inputs(
hidden_states_expanded, router_logits, num_experts, topk
)
)
w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
input_global_scale = torch.ones(
(num_experts,), dtype=torch.float32, device=hidden_states.device
)
w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
a2_global_scale = torch.ones(
(num_experts,), dtype=torch.float32, device=hidden_states.device
) # assume intermediate scale is 1.0
w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
w1,
w1_global_scale,
torch.ones(num_experts, dtype=torch.int32, device=w1.device)
* 2
* inter_dim,
)
w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
w2,
w2_global_scale,
torch.ones(num_experts, dtype=torch.int32, device=w2.device)
* hidden_dim,
)
w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
out = flashinfer_cutedsl_moe_masked(
(hidden_states_3d.to(hidden_states.device), None),
input_global_scale,
w1_fp4.permute(2, 0, 1),
w1_blockscale,
w1_alpha,
w2_fp4.permute(2, 0, 1),
a2_global_scale,
w2_blockscale,
w2_alpha,
masked_m.to(hidden_states.device),
)
# reference
a_fp4, a_scale_interleaved = fp4_quantize(
hidden_states, input_global_scale
)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
input_global_scale,
dtype=hidden_states.dtype,
device=hidden_states.device,
block_size=16,
)
w1_d = torch.empty(
(num_experts, 2 * inter_dim, hidden_dim),
device=w1.device,
dtype=w1.dtype,
)
w2_d = torch.empty(
(num_experts, hidden_dim, inter_dim),
device=w2.device,
dtype=w2.dtype,
)
for idx in range(0, num_experts):
w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
w1[idx], w1_global_scale[idx]
)
w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
w2[idx], w2_global_scale[idx]
)
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_fp4_sliced,
w1_blockscale_sliced,
w1_global_scale[idx],
dtype=w1.dtype,
device=w1.device,
block_size=16,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_fp4_sliced,
w2_blockscale_sliced,
w2_global_scale[idx],
dtype=w2.dtype,
device=w2.device,
block_size=16,
)
ref_output = torch_moe_nvfp4(
a_in_dtype,
w1_d,
w2_d,
topk,
routing_weights.to(a_in_dtype.device),
topk_idx.to(a_in_dtype.device),
)
out_weighted = torch.zeros_like(
ref_output, device=out.device, dtype=out.dtype
)
positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
rows, cols = positions[:, 0], positions[:, 1]
experts = topk_idx[rows, cols]
for i in range(num_experts):
mask = experts == i
if mask.any():
idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
r, c = rows[idx], cols[idx]
out_weighted[r] += out[i, : len(r), :] * routing_weights[
r, c
].to(out.device).unsqueeze(-1)
torch.testing.assert_close(
out_weighted.cpu(), ref_output.cpu(), atol=5e-2, rtol=5e-2
)
print(
f"Test passed with bs={bs}, hidden_dim={hidden_dim}, inter_dim={inter_dim}, topk={topk}"
)
if __name__ == "__main__":
unittest.main()
...@@ -24,20 +24,31 @@ class TestDeepseekR1Nvfp4CuteDSLDeepEP(CustomTestCase): ...@@ -24,20 +24,31 @@ class TestDeepseekR1Nvfp4CuteDSLDeepEP(CustomTestCase):
other_args = [ other_args = [
"--trust-remote-code", "--trust-remote-code",
"--disable-radix-cache", "--disable-radix-cache",
"--mem-fraction-static",
"0.89",
"--max-prefill-tokens",
"16384",
"--max-running-requests", "--max-running-requests",
"256", "256",
"--chunked-prefill-size", "--chunked-prefill-size",
"2048", "1024",
"--tp", "--tp",
"8", "4",
"--dp", "--dp",
"8", "4",
"--ep",
"4",
"--moe-dense-tp-size",
"1",
"--enable-dp-attention", "--enable-dp-attention",
"--enable-ep-moe",
"--quantization", "--quantization",
"modelopt_fp4", "modelopt_fp4",
"--enable-flashinfer-cutedsl-moe", "--attention-backend",
"--enable-deepep-moe", "trtllm_mla",
"--moe-a2a-backend",
"deepep",
"--moe-runner-backend",
"flashinfer_cutedsl",
"--deepep-mode", "--deepep-mode",
"low_latency", "low_latency",
] ]
...@@ -50,6 +61,7 @@ class TestDeepseekR1Nvfp4CuteDSLDeepEP(CustomTestCase): ...@@ -50,6 +61,7 @@ class TestDeepseekR1Nvfp4CuteDSLDeepEP(CustomTestCase):
**os.environ, **os.environ,
"SGLANG_DEEPEP_BF16_DISPATCH": "1", "SGLANG_DEEPEP_BF16_DISPATCH": "1",
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256", "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256",
"SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH": "0",
}, },
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment