Unverified Commit 08f425ba authored by Xinyu Chen's avatar Xinyu Chen Committed by GitHub
Browse files

CustomOp: test forward dispatch for grouped_topk (#31530)


Signed-off-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
parent a01f2fae
...@@ -8,6 +8,12 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`. ...@@ -8,6 +8,12 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
import pytest import pytest
import torch import torch
from vllm.config import (
CompilationConfig,
VllmConfig,
get_cached_compilation_config,
set_current_vllm_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
GroupedTopk, GroupedTopk,
fused_grouped_topk, fused_grouped_topk,
...@@ -41,6 +47,11 @@ def test_grouped_topk( ...@@ -41,6 +47,11 @@ def test_grouped_topk(
routed_scaling_factor: float, routed_scaling_factor: float,
dtype: torch.dtype, dtype: torch.dtype,
): ):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(custom_ops=["all", "+grouped_topk"])
)
get_cached_compilation_config.cache_clear()
current_platform.seed_everything(0) current_platform.seed_everything(0)
hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda") hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda") gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
...@@ -48,7 +59,7 @@ def test_grouped_topk( ...@@ -48,7 +59,7 @@ def test_grouped_topk(
(n_expert,), dtype=torch.float32, device="cuda" (n_expert,), dtype=torch.float32, device="cuda"
) )
with monkeypatch.context() as m: with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
grouped_topk = GroupedTopk( grouped_topk = GroupedTopk(
topk=topk, topk=topk,
...@@ -58,6 +69,7 @@ def test_grouped_topk( ...@@ -58,6 +69,7 @@ def test_grouped_topk(
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
) )
assert grouped_topk._forward_method.__name__ == "forward_cuda"
baseline_topk_weights, baseline_topk_ids = grouped_topk( baseline_topk_weights, baseline_topk_ids = grouped_topk(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=gating_output, gating_output=gating_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment