Unverified commit 63227acc, authored by Xin Yang and committed by GitHub
Browse files

[Kernel] Add topk_sigmoid kernel (#31246)


Signed-off-by: Xin Yang <xyangx@amazon.com>
parent e675dda6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import torch
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Benchmark sweep: the cartesian product of token counts, expert counts and
# top-k values; each tuple is one (num_tokens, num_experts, topk) data point.
num_tokens_range = [1 << shift for shift in range(0, 8, 2)]
num_experts_range = [16, 32, 64, 128, 256, 512]
topk_range = [3, 4]
configs = [*itertools.product(num_tokens_range, num_experts_range, topk_range)]
def torch_topk(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    scoring_func: str = "softmax",
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch top-k routing used as the benchmark baseline.

    Args:
        gating_output: Router logits; experts are on the last dimension.
        topk: Number of experts to select per row.
        renormalize: If True, rescale the selected weights so that they sum
            to 1 along the last dimension.
        scoring_func: Either "softmax" or "sigmoid".

    Returns:
        A ``(topk_weights, topk_ids)`` pair, both with ``topk`` entries on
        the last dimension.

    Raises:
        ValueError: If ``scoring_func`` is neither "softmax" nor "sigmoid".
    """
    # Score in float32 regardless of the input dtype.
    logits = gating_output.float()
    if scoring_func == "softmax":
        scores = torch.softmax(logits, dim=-1)
    elif scoring_func == "sigmoid":
        scores = torch.sigmoid(logits)
    else:
        # Fail loudly instead of silently benchmarking sigmoid for a typo.
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
def get_benchmark(scoring_func):
    """Build a Triton perf-report benchmark for one scoring function.

    Args:
        scoring_func: Scoring function name forwarded to both providers
            (``torch_topk`` and ``fused_topk``).

    Returns:
        The decorated ``benchmark`` callable; call ``.run(...)`` on it to
        sweep every entry of ``configs`` for the "torch" and "vllm" providers.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens", "num_experts", "topk"],
            x_vals=[list(config) for config in configs],
            line_arg="provider",
            line_vals=["torch", "vllm"],
            line_names=["Torch", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"fused-topk-perf-{scoring_func}",
            args={},
        )
    )
    def benchmark(num_tokens, num_experts, topk, provider):
        dtype = torch.bfloat16
        hidden_size = 1024
        renormalize = True

        hidden_states = torch.randn(
            (num_tokens, hidden_size), dtype=dtype, device="cuda"
        )
        gating_output = torch.randn(
            (num_tokens, num_experts), dtype=dtype, device="cuda"
        )

        if provider == "torch":

            def run():
                return torch_topk(
                    gating_output=gating_output,
                    topk=topk,
                    renormalize=renormalize,
                    scoring_func=scoring_func,
                )

        else:

            def run():
                return fused_topk(
                    hidden_states=hidden_states,
                    gating_output=gating_output,
                    topk=topk,
                    renormalize=renormalize,
                    scoring_func=scoring_func,
                )

        # One timing per requested quantile, in order: median, 20th, 80th.
        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)

        # perf_report unpacks the result as (mean, min, max). The values are
        # latencies rather than throughputs, so the low quantile is the min;
        # returning (ms, max_ms, min_ms) would swap the min/max columns.
        # Scale by 1000 to report microseconds, matching ylabel="us".
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the MoE topk kernel.")
    # Only these two scoring functions are implemented by the fused kernel
    # path; rejecting anything else at parse time avoids starting a sweep
    # that would fail (or measure the wrong thing) part-way through.
    parser.add_argument(
        "--scoring-func",
        type=str,
        default="softmax",
        choices=["softmax", "sigmoid"],
    )
    parser.add_argument("--save-path", type=str, default="./configs/fused_topk/")
    args = parser.parse_args()

    # Get the benchmark function
    benchmark = get_benchmark(args.scoring_func)
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
......@@ -4,7 +4,13 @@
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output, bool renormalize);
torch::Tensor& gating_output, bool renormalize,
std::optional<torch::Tensor> bias);
// Top-k expert selection over sigmoid-scored gating outputs.
// Takes the same parameter list as topk_softmax: topk_weights, topk_indices
// and token_expert_indices are output tensors passed by reference,
// gating_output is the router input, renormalize toggles renormalization of
// the selected weights, and bias is an optional tensor.
void topk_sigmoid(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output, bool renormalize,
std::optional<torch::Tensor> bias);
void moe_sum(torch::Tensor& input, torch::Tensor& output);
......
This diff is collapsed.
......@@ -5,9 +5,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()");
"token_expert_indices, Tensor gating_output, bool renormalize, Tensor? "
"bias) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
// Apply topk sigmoid to the gating outputs.
m.def(
"topk_sigmoid(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize, Tensor? "
"bias) -> ()");
m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid);
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE fused topk kernel
Run `pytest tests/kernels/moe/test_fused_topk.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.platforms import current_platform
def torch_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
e_score_correction_bias: torch.Tensor = None,
scoring_func: str = "softmax",
):
if scoring_func == "softmax":
scores = torch.softmax(gating_output.float(), dim=-1)
else:
assert scoring_func == "sigmoid"
scores = torch.sigmoid(gating_output.float())
if e_score_correction_bias is not None:
num_experts = gating_output.shape[-1]
scores_for_choice = scores.view(
-1, num_experts
) + e_score_correction_bias.unsqueeze(0)
_, topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1)
topk_weights = scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 56])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [6, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    topk: int,
    renormalize: bool,
    scoring_func: str,
    dtype: torch.dtype,
):
    """Compare `fused_topk` with the PyTorch reference on random inputs."""
    torch.manual_seed(0)
    device = "cuda"
    hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device)
    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device=device)

    # Routing arguments shared by the reference and the kernel under test.
    routing_kwargs = dict(
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
        scoring_func=scoring_func,
    )
    ref_weights, ref_ids = torch_topk(**routing_kwargs)
    out_weights, out_ids, _ = fused_topk(hidden_states=hidden_states, **routing_kwargs)

    # Weights may differ by rounding; the selected expert ids must match exactly.
    torch.testing.assert_close(
        ref_weights.to(torch.float32), out_weights, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(ref_ids.to(torch.int32), out_ids, atol=0, rtol=0)
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 56])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [6, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_bias(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    topk: int,
    renormalize: bool,
    scoring_func: str,
    dtype: torch.dtype,
):
    """Compare `fused_topk_bias` with the PyTorch reference on random inputs."""
    torch.manual_seed(0)
    device = "cuda"
    hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device)
    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device=device)
    e_score_correction_bias = torch.randn(
        (num_experts,), dtype=torch.float32, device=device
    )

    # Routing arguments shared by the reference and the kernel under test.
    routing_kwargs = dict(
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
        e_score_correction_bias=e_score_correction_bias,
        scoring_func=scoring_func,
    )
    ref_weights, ref_ids = torch_topk(**routing_kwargs)
    out_weights, out_ids = fused_topk_bias(
        hidden_states=hidden_states, **routing_kwargs
    )

    # Weights may differ by rounding; the selected expert ids must match exactly.
    torch.testing.assert_close(
        ref_weights.to(torch.float32), out_weights, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(ref_ids.to(torch.int32), out_ids, atol=0, rtol=0)
......@@ -18,7 +18,9 @@ from vllm.model_executor.layers.activation import (
SiluAndMul,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
dispatch_topk_func,
dispatch_topk_sigmoid_func,
dispatch_topk_softmax_func,
vllm_topk_sigmoid,
vllm_topk_softmax,
)
from vllm.model_executor.layers.layernorm import (
......@@ -133,8 +135,8 @@ def test_enabled_ops_invalid(env: str):
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_dispatch(use_rocm_aiter: bool):
topk_func = dispatch_topk_func(use_rocm_aiter)
def test_topk_softmax_dispatch(use_rocm_aiter: bool):
topk_func = dispatch_topk_softmax_func(use_rocm_aiter)
if current_platform.is_rocm() and use_rocm_aiter:
assert topk_func == rocm_aiter_ops.topk_softmax
......@@ -142,6 +144,18 @@ def test_topk_dispatch(use_rocm_aiter: bool):
assert topk_func == vllm_topk_softmax
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
    """The sigmoid dispatcher picks the AITER op only on ROCm with AITER on."""
    selected = dispatch_topk_sigmoid_func(use_rocm_aiter)

    expect_aiter = current_platform.is_rocm() and use_rocm_aiter
    expected = rocm_aiter_ops.topk_sigmoid if expect_aiter else vllm_topk_sigmoid
    assert selected == expected
@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
......
......@@ -200,6 +200,24 @@ def _rocm_aiter_topk_softmax_fake(
pass
def _rocm_aiter_topk_sigmoid_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Forward to AITER's `topk_sigmoid` with the three tensors unchanged.

    Returns nothing; the custom-op registration in this file lists
    `topk_weights` and `topk_indices` as the mutated arguments.
    """
    # Imported at call time rather than at module import.
    from aiter import topk_sigmoid

    topk_sigmoid(topk_weights, topk_indices, gating_output)
def _rocm_aiter_topk_sigmoid_fake(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
gating_output: torch.Tensor,
) -> None:
pass
def _rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
......@@ -985,6 +1003,14 @@ class rocm_aiter_ops:
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_sigmoid",
op_func=_rocm_aiter_topk_sigmoid_impl,
mutates_args=["topk_weights", "topk_indices"],
fake_impl=_rocm_aiter_topk_sigmoid_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk",
op_func=_rocm_aiter_biased_grouped_topk_impl,
......@@ -1272,6 +1298,19 @@ class rocm_aiter_ops:
)
return topk_weights, topk_indices
@staticmethod
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    """Run the registered `rocm_aiter_topk_sigmoid` op and return its outputs.

    Only `topk_weights`, `topk_indices` and `gating_output` are handed to
    the op; the same `topk_weights` and `topk_indices` objects are returned.

    NOTE(review): `token_expert_indices` and `renormalize` are accepted but
    never forwarded. Confirm whether the AITER kernel renormalizes on its
    own; otherwise `renormalize` has no effect on this path.
    """
    aiter_topk_sigmoid = torch.ops.vllm.rocm_aiter_topk_sigmoid
    aiter_topk_sigmoid(topk_weights, topk_indices, gating_output)
    return topk_weights, topk_indices
@staticmethod
def biased_grouped_topk(
gating_output: torch.Tensor,
......
......@@ -2177,9 +2177,33 @@ def topk_softmax(
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
) -> None:
torch.ops._moe_C.topk_softmax(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> None:
    """Call the `_moe_C.topk_sigmoid` custom op.

    All six arguments are passed through positionally and unchanged, and
    nothing is returned. `e_score_correction_bias` may be None.
    """
    moe_topk_sigmoid = torch.ops._moe_C.topk_sigmoid
    moe_topk_sigmoid(
        topk_weights,
        topk_ids,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
......
......@@ -106,14 +106,14 @@ def _quant_flags_to_group_shape(
class RoutingMethodType(IntEnum):
# Default: Softmax -> TopK
Default = (0,)
# Renormalize: TopK -> Softmax
# Renormalize: TopK -> Softmax/Sigmoid
Renormalize = (1,)
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
# -> Top8 experts from the Top4 groups
DeepSeekV3 = (2,)
# Llama4: Top1 -> Sigmoid
Llama4 = (3,)
# RenormalizeNaive: Softmax -> TopK -> Renormalize
# RenormalizeNaive: Softmax/Sigmoid -> TopK -> Renormalize
RenormalizeNaive = (4,)
# TopK: TopK (no softmax)
TopK = (5,)
......
......@@ -4,6 +4,8 @@ from collections.abc import Callable
import torch
import vllm._custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
......@@ -12,15 +14,106 @@ from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Call `ops.topk_softmax` and hand back the weight and index tensors.

    Every argument is forwarded positionally and unchanged. The returned
    pair is the very same `topk_weights` and `topk_indices` objects that
    were passed in.
    """
    kernel_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
    ops.topk_softmax(*kernel_args)
    return topk_weights, topk_indices
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Call `ops.topk_sigmoid` and hand back the weight and index tensors.

    Every argument is forwarded positionally and unchanged. The returned
    pair is the very same `topk_weights` and `topk_indices` objects that
    were passed in.
    """
    kernel_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
    ops.topk_sigmoid(*kernel_args)
    return topk_weights, topk_indices
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
scoring_func: str = "softmax",
indices_type: torch.dtype | None = None,
):
if not rocm_aiter_ops.is_fused_moe_enabled():
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch"
)
M, _ = hidden_states.size()
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
if scoring_func == "softmax":
topk_weights, topk_ids = vllm_topk_softmax(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
return topk_weights, topk_ids
elif scoring_func == "sigmoid":
topk_weights, topk_ids = vllm_topk_sigmoid(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
return topk_weights, topk_ids
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
n_routed_experts = gating_output.shape[-1]
scores = gating_output.softmax(dim=-1)
if scoring_func == "softmax":
scores = gating_output.softmax(dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
......@@ -43,6 +136,7 @@ class FusedTopKBiasRouter(BaseRouter):
global_num_experts: int,
eplb_state: EplbLayerState,
e_score_correction_bias: torch.Tensor,
scoring_func: str,
renormalize: bool = True,
routed_scaling_factor: float = 1.0,
enable_eplb: bool = False,
......@@ -57,6 +151,7 @@ class FusedTopKBiasRouter(BaseRouter):
)
self.e_score_correction_bias = e_score_correction_bias
self.renormalize = renormalize
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
@property
......@@ -80,6 +175,7 @@ class FusedTopKBiasRouter(BaseRouter):
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
scoring_func=self.scoring_func,
)
if self.routed_scaling_factor != 1.0:
......
......@@ -16,7 +16,7 @@ def vllm_topk_softmax(
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
......@@ -29,7 +29,25 @@ def vllm_topk_softmax(
return topk_weights, topk_indices
def dispatch_topk_func(
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
    """Call `ops.topk_sigmoid` (no bias) and return the weights and indices.

    The five arguments are forwarded positionally and unchanged. The
    returned pair is the same `topk_weights` and `topk_indices` objects
    that were passed in.
    """
    kernel_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )
    ops.topk_sigmoid(*kernel_args)
    return topk_weights, topk_indices
def dispatch_topk_softmax_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
......@@ -37,12 +55,21 @@ def dispatch_topk_func(
return vllm_topk_softmax
def dispatch_topk_sigmoid_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    """Pick the top-k sigmoid implementation.

    Returns `rocm_aiter_ops.topk_sigmoid` when `use_rocm_aiter` is truthy,
    otherwise `vllm_topk_sigmoid`.
    """
    return rocm_aiter_ops.topk_sigmoid if use_rocm_aiter else vllm_topk_sigmoid
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: torch.dtype | None = None,
scoring_func: str = "softmax",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
......@@ -61,12 +88,26 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
if scoring_func == "softmax":
topk_func = dispatch_topk_softmax_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
elif scoring_func == "sigmoid":
topk_func = dispatch_topk_sigmoid_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
return topk_weights, topk_ids, token_expert_indices
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
class FusedTopKRouter(BaseRouter):
......@@ -82,7 +123,6 @@ class FusedTopKRouter(BaseRouter):
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
assert scoring_func == "softmax", "FusedTopKRouter only supports softmax."
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
......@@ -91,6 +131,7 @@ class FusedTopKRouter(BaseRouter):
indices_type_getter=indices_type_getter,
)
self.renormalize = renormalize
self.scoring_func = scoring_func
@property
def routing_method_type(self) -> RoutingMethodType:
......@@ -113,6 +154,7 @@ class FusedTopKRouter(BaseRouter):
topk=self.top_k,
renormalize=self.renormalize,
indices_type=indices_type,
scoring_func=self.scoring_func,
)
return topk_weights, topk_ids
......@@ -143,17 +143,13 @@ def create_fused_moe_router(
router.capture = capture
return router
if scoring_func != "softmax":
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."
)
if e_score_correction_bias is not None:
router = FusedTopKBiasRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
e_score_correction_bias=e_score_correction_bias,
scoring_func=scoring_func,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,
......
......@@ -100,9 +100,6 @@ class MiniMaxM2MoE(nn.Module):
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
scoring_func=config.scoring_func,
use_grouped_topk=True,
num_expert_group=1,
topk_group=1,
e_score_correction_bias=self.e_score_correction_bias,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment