Unverified Commit 8a3cd90a authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

[Kernel] Add fused grouped_topk kernel for MoE (#23274)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
parent 2a167b2e
......@@ -817,7 +817,9 @@ set(VLLM_MOE_EXT_SRC
"csrc/moe/topk_softmax_kernels.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu")
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
......
This diff is collapsed.
......@@ -22,6 +22,11 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor num_tokens_post_pad, int64_t top_k,
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t BLOCK_SIZE_K, int64_t bit);
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
double routed_scaling_factor);
#endif
bool moe_permute_unpermute_supported();
......
......@@ -78,6 +78,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"output_tensor) -> ()");
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
// Apply grouped topk routing to select experts.
m.def(
"grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int "
"topk_group, int topk, bool renormalize, float "
"routed_scaling_factor) -> (Tensor, Tensor)");
m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
#endif
}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE grouped topk kernel
Run `pytest tests/kernels/moe/test_grouped_topk.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk,
grouped_topk)
from vllm.platforms import current_platform
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("num_expert_group", [8])
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
n_hidden: int, n_expert: int, topk: int,
renormalize: bool, num_expert_group: int,
topk_group: int, scoring_func: str,
routed_scaling_factor: float, dtype: torch.dtype):
current_platform.seed_everything(0)
hidden_states = torch.randn((n_token, n_hidden),
dtype=dtype,
device="cuda")
gating_output = torch.randn((n_token, n_expert),
dtype=dtype,
device="cuda")
e_score_correction_bias = torch.randn((n_expert, ),
dtype=torch.float32,
device="cuda")
with monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
baseline_topk_weights, baseline_topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
test_topk_weights, test_topk_ids = fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
if renormalize:
torch.testing.assert_close(baseline_topk_weights,
test_topk_weights,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_topk_ids,
test_topk_ids,
atol=0,
rtol=0)
......@@ -1502,6 +1502,17 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
gating_output)
def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor,
num_expert_group: int, topk_group: int, topk: int,
renormalize: bool, routed_scaling_factor: float):
if not current_platform.is_cuda():
raise NotImplementedError("The fused grouped_topk kernel is only "
"available on CUDA platforms")
return torch.ops._moe_C.grouped_topk(scores, scores_with_bias,
num_expert_group, topk_group, topk,
renormalize, routed_scaling_factor)
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
b_qweight: torch.Tensor,
b_bias: Optional[torch.Tensor],
......
......@@ -131,6 +131,7 @@ if TYPE_CHECKING:
VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
......@@ -963,6 +964,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_SKIP_DEEP_GEMM_WARMUP":
lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))),
# Whether to use fused grouped_topk used for MoE expert selection.
"VLLM_USE_FUSED_MOE_GROUPED_TOPK":
lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))),
# Allow use of FlashInfer MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
......@@ -1229,6 +1234,7 @@ def compute_hash() -> str:
"VLLM_DISABLED_KERNELS",
"VLLM_USE_DEEP_GEMM",
"VLLM_USE_TRTLLM_FP4_GEMM",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP8",
"VLLM_USE_FLASHINFER_MOE_FP4",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
......
......@@ -949,8 +949,23 @@ def grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \
current_platform.is_cuda() and \
num_expert_group <= 32 and topk <= 32 and \
e_score_correction_bias is not None:
return fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor)
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch")
......@@ -996,9 +1011,38 @@ def grouped_topk(
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def fused_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
e_score_correction_bias: torch.Tensor,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch")
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
topk_values, topk_indices = ops.grouped_topk(
scores, scores_with_bias.to(scores.dtype), num_expert_group,
topk_group, topk, renormalize, routed_scaling_factor)
return topk_values.to(torch.float32), topk_indices.to(torch.int32)
def get_config_dtype_str(
dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment