Commit 0da93439 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.1rc0' into v0.18.1rc0-ori

parents 25f2f756 298e5108
model_name: "deepseek-ai/DeepSeek-V3.2"
accuracy_threshold: 0.95
num_questions: 1319
num_fewshot: 5
startup_max_wait_seconds: 1200
server_args: >-
--enforce-eager
--max-model-len 4096
--tensor-parallel-size 8
--enable-expert-parallel
--attention-backend=TRITON_ATTN
--speculative-config '{"method":"mtp","num_speculative_tokens":3}'
DeepSeek-R1-TP_MI325.yaml
DeepSeek-R1-DP_MI325.yaml
DeepSeek-V3.2-TP_MI325.yaml
DeepSeek-V3.2-DP_MI325.yaml
Qwen3.5-35B-A3B-DEP2.yaml
Qwen3.5-35B-A3B-FP8-DEP2.yaml
model_name: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
accuracy_threshold: 0.29
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=cutlass"
......@@ -15,3 +15,4 @@ Mixtral-8x7B-BF16-fi-cutlass.yaml
Mixtral-8x7B-BF16-triton.yaml
Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml
Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml
Nemotron-Nano-30B-NvFp4-ModelOpt-vllm-cutlass.yaml
......@@ -64,6 +64,16 @@ def test_gsm8k_correctness(config_filename):
"Marlin kernels are not supported."
)
# TODO(akaratza): Enable DeepSeek-V3.2 and DeepSeek-R1 on ROCm platforms
if current_platform.is_rocm() and (
"deepseek-ai/DeepSeek-V3.2" in eval_config["model_name"]
or "deepseek-ai/DeepSeek-R1" in eval_config["model_name"]
):
pytest.skip(
"Skipping DeepSeek-V3.2 and DeepSeek-R1 on ROCm platforms "
"due to agent pool disk space issues and pod evictions."
)
# Parse server arguments from config (use shlex to handle quoted strings)
server_args_str = eval_config.get("server_args", "")
server_args = shlex.split(server_args_str) if server_args_str else []
......
......@@ -14,8 +14,19 @@ from vllm.config import (
)
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
# CudaPlatform and RocmPlatform import their respective compiled C extensions
# at module level, raising ModuleNotFoundError on incompatible builds.
try:
from vllm.platforms.cuda import CudaPlatform
except (ImportError, ModuleNotFoundError):
CudaPlatform = None
try:
from vllm.platforms.rocm import RocmPlatform
except (ImportError, ModuleNotFoundError):
RocmPlatform = None
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend
......@@ -101,6 +112,8 @@ def test_backend_selection(
assert backend.get_name() == "CPU_ATTN"
elif device == "hip":
if RocmPlatform is None:
pytest.skip("RocmPlatform not available")
with patch("vllm.platforms.current_platform", RocmPlatform()):
if use_mla:
# ROCm MLA backend logic:
......@@ -126,6 +139,8 @@ def test_backend_selection(
assert backend.get_name() == expected
elif device == "cuda":
if CudaPlatform is None:
pytest.skip("CudaPlatform not available")
with patch("vllm.platforms.current_platform", CudaPlatform()):
capability = torch.cuda.get_device_capability()
if use_mla:
......@@ -214,7 +229,7 @@ def test_backend_selection(
assert backend.get_name() == expected
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", ["cpu", "cuda", "hip"])
def test_fp32_fallback(device: str):
"""Test attention backend selection with fp32."""
# Use default config (no backend specified)
......@@ -227,10 +242,25 @@ def test_fp32_fallback(device: str):
assert backend.get_name() == "CPU_ATTN"
elif device == "cuda":
if CudaPlatform is None:
pytest.skip("CudaPlatform not available")
with patch("vllm.platforms.current_platform", CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None)
assert backend.get_name() == "FLEX_ATTENTION"
elif device == "hip":
if RocmPlatform is None:
pytest.skip("RocmPlatform not available")
# ROCm backends do not support head_size=16 (minimum is 32).
# No known HuggingFace transformer model uses head_size=16.
# Revisit if a real model with this head size is identified
# and accuracy-tested.
with (
patch("vllm.platforms.current_platform", RocmPlatform()),
pytest.raises(ValueError, match="No valid attention backend"),
):
get_attn_backend(16, torch.float32, None)
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation."""
......@@ -367,6 +397,8 @@ def test_per_head_quant_scales_backend_selection(
attention_config=attention_config, cache_config=cache_config
)
if CudaPlatform is None:
pytest.skip("CudaPlatform not available")
with (
set_current_vllm_config(vllm_config),
patch("vllm.platforms.current_platform", CudaPlatform()),
......
......@@ -48,7 +48,7 @@ def get_attn_isa(
else:
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
return "neon"
elif torch._C._cpu._is_amx_tile_supported():
elif torch.cpu._is_amx_tile_supported():
return "amx"
else:
return "vec"
......@@ -400,9 +400,7 @@ def test_varlen_with_paged_kv_normal_vec(
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["amx"])
@pytest.mark.skipif(
not torch._C._cpu._is_amx_tile_supported(), reason="no AMX support."
)
@pytest.mark.skipif(not torch.cpu._is_amx_tile_supported(), reason="no AMX support.")
def test_varlen_with_paged_kv_normal_amx(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Standalone unit tests for trtllm_prefill_attn_kvfp8_dequant.
Tests both contiguous and non-contiguous (cross-layer unified) KV cache
layouts against a pure-PyTorch reference implementation.
"""
import pytest
import torch
from vllm.platforms import current_platform
# Skip the whole module on ROCm, where this dequant path is not supported.
if current_platform.is_rocm():
pytest.skip(
"trtllm kvfp8 dequant is not supported on ROCm.",
allow_module_level=True,
)
# FP8 dtype reported by the current platform; default target of to_float8().
FP8_DTYPE = current_platform.fp8_dtype()
# Number of KV-cache blocks allocated by most tests in this module.
NUM_BLOCKS = 128
def to_float8(x, dtype=None):
if dtype is None:
dtype = FP8_DTYPE
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax * 0.1
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
def make_contiguous_kv_cache(num_blocks, num_kv_heads, block_size, head_size):
    """Build a contiguous fp8 KV cache (HND layout) filled with random data.

    Returns:
        ``(kv_cache, scale)`` as produced by ``to_float8`` for a CUDA tensor of
        shape ``(num_blocks, 2, num_kv_heads, block_size, head_size)``.
    """
    cache_shape = (num_blocks, 2, num_kv_heads, block_size, head_size)
    source = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda")
    quantized, inv_scale = to_float8(source)
    return quantized, inv_scale
def make_cross_layer_kv_cache(
    num_blocks,
    num_kv_heads,
    block_size,
    head_size,
    num_layers=4,
):
    """
    Build a non-contiguous single-layer view over a cross-layer allocation.

    Physical layout: (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
    Returned view:   (num_blocks, 2, num_kv_heads, block_size, head_size)

    The view selects layer 0, so dims 0, 1 and 2 keep strides that step over
    the ``num_layers`` axis. Returns ``(layer_view, scale)``.
    """
    physical_shape = (
        num_blocks,
        2,
        num_kv_heads,
        num_layers,
        block_size,
        head_size,
    )
    backing = torch.randn(physical_shape, dtype=torch.bfloat16, device="cuda")
    quantized, scale = to_float8(backing)

    # Index out layer 0; trailing dims are taken in full.
    layer_view = quantized[:, :, :, 0]
    assert not layer_view.is_contiguous(), (
        f"Expected non-contiguous view, got strides {layer_view.stride()}"
    )
    return layer_view, scale
def ref_dequant(kv_cache, block_tables, k_scale, v_scale, dequant_dtype):
    """Pure PyTorch reference: gather pages and dequantize fp8 -> dequant_dtype.

    Args:
        kv_cache: Paged cache indexed as ``kv_cache[page, 0]`` for K and
            ``kv_cache[page, 1]`` for V.
        block_tables: 2-D integer tensor ``(batch_size, num_pages_per_seq)`` of
            page indices; entries ``<= 0`` are treated as padding and skipped.
        k_scale: Single-element tensor multiplied into K pages.
        v_scale: Single-element tensor multiplied into V pages.
        dequant_dtype: dtype of the returned buffer.

    Returns:
        A zero-initialised tensor with ``batch_size * num_pages_per_seq + 1``
        pages. Entry ``(b, p)`` of ``block_tables`` is written to output page
        ``b * num_pages_per_seq + p + 1``; output page 0 is never written.
    """
    batch_size, num_pages_per_seq = block_tables.shape
    s = kv_cache.shape
    out = torch.zeros(
        batch_size * num_pages_per_seq + 1,
        s[1],
        s[2],
        s[3],
        s[4],
        dtype=dequant_dtype,
        device=kv_cache.device,
    )

    # Read the scalars and the page table back to the host once, rather than
    # calling .item() (a device sync on CUDA tensors) for every page.
    k_mul = k_scale.item()
    v_mul = v_scale.item()
    flat_pages = block_tables.reshape(-1).tolist()

    # Row-major flattening makes the flat position equal b * pages + p, so the
    # destination page is simply that position shifted by one.
    for mock_idx, page_idx in enumerate(flat_pages, start=1):
        if page_idx <= 0:
            continue
        out[mock_idx, 0] = (kv_cache[page_idx, 0].float() * k_mul).to(dequant_dtype)
        out[mock_idx, 1] = (kv_cache[page_idx, 1].float() * v_mul).to(dequant_dtype)
    return out
@pytest.mark.parametrize("num_kv_heads", [1, 8])
@pytest.mark.parametrize("head_size", [64, 128])
@pytest.mark.parametrize("block_size", [16, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("num_pages_per_seq", [3, 8])
@pytest.mark.parametrize("contiguous", [True, False])
@torch.inference_mode()
def test_trtllm_kvfp8_dequant(
    num_kv_heads: int,
    head_size: int,
    block_size: int,
    batch_size: int,
    num_pages_per_seq: int,
    contiguous: bool,
):
    """Check the dequant kernel against ``ref_dequant`` over a shape sweep.

    Runs with both the contiguous cache and the non-contiguous cross-layer
    view, and verifies the returned block table as well as the page contents.
    """
    from vllm.v1.attention.backends.flashinfer import (
        trtllm_prefill_attn_kvfp8_dequant,
    )

    # Scope the CUDA default device to this test. torch.set_default_device()
    # is process-global and would leak into every test that runs afterwards.
    with torch.device("cuda"):
        if contiguous:
            kv_cache, scale = make_contiguous_kv_cache(
                NUM_BLOCKS,
                num_kv_heads,
                block_size,
                head_size,
            )
        else:
            kv_cache, scale = make_cross_layer_kv_cache(
                NUM_BLOCKS,
                num_kv_heads,
                block_size,
                head_size,
            )

        k_scale = scale.clone()
        v_scale = scale.clone()

        # Page indices start at 1 so that no entry is treated as padding.
        block_tables = torch.randint(
            1,
            NUM_BLOCKS,
            (batch_size, num_pages_per_seq),
            dtype=torch.int32,
        )

        mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
            kv_cache,
            block_tables,
            k_scale,
            v_scale,
            torch.bfloat16,
        )

        ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)

        expected_bt = torch.arange(
            1,
            batch_size * num_pages_per_seq + 1,
            dtype=torch.int32,
            device="cuda",
        ).reshape(batch_size, num_pages_per_seq)
        torch.testing.assert_close(mock_block_table, expected_bt)

        # Page 0 is padding (never written), compare only pages 1+
        torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
@torch.inference_mode()
def test_block_tables_with_zero_pages():
    """Pages with index <= 0 must be skipped (early return in kernel)."""
    from vllm.v1.attention.backends.flashinfer import (
        trtllm_prefill_attn_kvfp8_dequant,
    )

    num_kv_heads, block_size, head_size = 8, 16, 64

    # Scope the CUDA default device to this test. torch.set_default_device()
    # is process-global and would leak into every test that runs afterwards.
    with torch.device("cuda"):
        kv_cache, scale = make_contiguous_kv_cache(
            NUM_BLOCKS,
            num_kv_heads,
            block_size,
            head_size,
        )
        k_scale = v_scale = scale.clone()

        # Mix of valid pages and zeros (padding)
        block_tables = torch.tensor(
            [[5, 0, 10], [0, 0, 0], [3, 7, 0]],
            dtype=torch.int32,
            device="cuda",
        )

        mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
            kv_cache,
            block_tables,
            k_scale,
            v_scale,
            torch.bfloat16,
        )

        ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)

        # Only compare pages that were actually written (non-zero page indices).
        # The table is copied to the host once instead of one .item() per entry.
        pages_per_seq = block_tables.shape[1]
        for b, row in enumerate(block_tables.tolist()):
            for p, page in enumerate(row):
                if page <= 0:
                    continue
                idx = b * pages_per_seq + p + 1
                torch.testing.assert_close(
                    mock_kv_cache[idx],
                    ref[idx],
                    atol=1e-3,
                    rtol=1e-3,
                )
@torch.inference_mode()
def test_all_zero_block_tables():
    """All-zero block_tables: the call must succeed with correctly shaped outputs.

    Only the output shapes are asserted; the contents of the returned cache
    are not inspected here.
    """
    from vllm.v1.attention.backends.flashinfer import (
        trtllm_prefill_attn_kvfp8_dequant,
    )

    num_kv_heads, block_size, head_size = 4, 16, 64
    batch_size, num_pages_per_seq = 2, 4

    # Scope the CUDA default device to this test. torch.set_default_device()
    # is process-global and would leak into every test that runs afterwards.
    with torch.device("cuda"):
        kv_cache, scale = make_contiguous_kv_cache(
            NUM_BLOCKS,
            num_kv_heads,
            block_size,
            head_size,
        )
        k_scale = v_scale = scale.clone()

        block_tables = torch.zeros(
            batch_size, num_pages_per_seq, dtype=torch.int32, device="cuda"
        )

        # Should not crash even though no pages are valid
        mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
            kv_cache,
            block_tables,
            k_scale,
            v_scale,
            torch.bfloat16,
        )

    assert mock_kv_cache.shape[0] == batch_size * num_pages_per_seq + 1
    assert mock_block_table.shape == (batch_size, num_pages_per_seq)
@torch.inference_mode()
def test_different_k_v_scales():
    """Verify K and V are dequantized with independent scales."""
    from vllm.v1.attention.backends.flashinfer import (
        trtllm_prefill_attn_kvfp8_dequant,
    )

    num_kv_heads, block_size, head_size = 8, 16, 64

    # Scope the CUDA default device to this test. torch.set_default_device()
    # is process-global and would leak into every test that runs afterwards.
    with torch.device("cuda"):
        # The scale returned with the cache is discarded; distinct K and V
        # scales are supplied below so a mix-up would show in the comparison.
        kv_cache, _ = make_contiguous_kv_cache(
            NUM_BLOCKS,
            num_kv_heads,
            block_size,
            head_size,
        )
        k_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
        v_scale = torch.tensor([2.0], dtype=torch.float32, device="cuda")

        block_tables = torch.tensor([[1, 2]], dtype=torch.int32, device="cuda")

        mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
            kv_cache,
            block_tables,
            k_scale,
            v_scale,
            torch.bfloat16,
        )

        ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
        torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
@torch.inference_mode()
def test_single_page_per_seq():
    """Minimum grid dim 1 = 1 page per sequence."""
    from vllm.v1.attention.backends.flashinfer import (
        trtllm_prefill_attn_kvfp8_dequant,
    )

    num_kv_heads, block_size, head_size = 8, 16, 128

    # Scope the CUDA default device to this test. torch.set_default_device()
    # is process-global and would leak into every test that runs afterwards.
    with torch.device("cuda"):
        kv_cache, scale = make_contiguous_kv_cache(
            NUM_BLOCKS,
            num_kv_heads,
            block_size,
            head_size,
        )
        k_scale = v_scale = scale.clone()

        # Three sequences, one page each.
        block_tables = torch.tensor(
            [[5], [10], [20]], dtype=torch.int32, device="cuda"
        )

        mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
            kv_cache,
            block_tables,
            k_scale,
            v_scale,
            torch.bfloat16,
        )

        ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
        torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
@torch.inference_mode()
def test_large_page_indices():
    """Page indices near the top of the buffer stress offset arithmetic."""
    from vllm.v1.attention.backends.flashinfer import (
        trtllm_prefill_attn_kvfp8_dequant,
    )

    num_kv_heads, block_size, head_size = 8, 16, 128
    large_num_blocks = 32768

    # Scope the CUDA default device to this test. torch.set_default_device()
    # is process-global and would leak into every test that runs afterwards.
    with torch.device("cuda"):
        kv_cache, scale = make_contiguous_kv_cache(
            large_num_blocks,
            num_kv_heads,
            block_size,
            head_size,
        )
        k_scale = v_scale = scale.clone()

        # Use page indices near the top of the buffer
        block_tables = torch.tensor(
            [[large_num_blocks - 1, large_num_blocks - 2, 1]],
            dtype=torch.int32,
            device="cuda",
        )

        mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
            kv_cache,
            block_tables,
            k_scale,
            v_scale,
            torch.bfloat16,
        )

        ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
        torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
@torch.inference_mode()
def test_large_block_size():
    """block_size=64 -> HEAD_STRIDE=8192, large tl.arange per thread block."""
    from vllm.v1.attention.backends.flashinfer import (
        trtllm_prefill_attn_kvfp8_dequant,
    )

    num_kv_heads, block_size, head_size = 4, 64, 128

    # Scope the CUDA default device to this test. torch.set_default_device()
    # is process-global and would leak into every test that runs afterwards.
    with torch.device("cuda"):
        kv_cache, scale = make_contiguous_kv_cache(
            NUM_BLOCKS,
            num_kv_heads,
            block_size,
            head_size,
        )
        k_scale = v_scale = scale.clone()

        # Page indices start at 1 so that no entry is treated as padding.
        block_tables = torch.randint(
            1,
            NUM_BLOCKS,
            (2, 4),
            dtype=torch.int32,
            device="cuda",
        )

        mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
            kv_cache,
            block_tables,
            k_scale,
            v_scale,
            torch.bfloat16,
        )

        ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
        torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
@torch.inference_mode()
def test_cross_layer_many_layers():
    """
    Non-contiguous with 36 layers -- matches real gpt-oss-120b.
    Strides are far from contiguous (factor of 36 in the gaps).
    """
    from vllm.v1.attention.backends.flashinfer import (
        trtllm_prefill_attn_kvfp8_dequant,
    )

    num_kv_heads, block_size, head_size = 8, 16, 64
    num_layers = 36

    # Scope the CUDA default device to this test. torch.set_default_device()
    # is process-global and would leak into every test that runs afterwards.
    with torch.device("cuda"):
        kv_cache, scale = make_cross_layer_kv_cache(
            NUM_BLOCKS,
            num_kv_heads,
            block_size,
            head_size,
            num_layers=num_layers,
        )
        k_scale = v_scale = scale.clone()

        # Page indices start at 1 so that no entry is treated as padding.
        block_tables = torch.randint(
            1,
            NUM_BLOCKS,
            (4, 6),
            dtype=torch.int32,
            device="cuda",
        )

        mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
            kv_cache,
            block_tables,
            k_scale,
            v_scale,
            torch.bfloat16,
        )

        ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
        torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
......@@ -280,21 +280,22 @@ def test_rms_norm(
assert torch.allclose(ref_residual, ops_residual)
output = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
if group_size is None:
scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
opcheck(
torch.ops._C.rms_norm_dynamic_per_token_quant,
(output, x, layer.weight, scales, 1e-5, scale_ub, residual),
)
else:
# TODO(luka/eliza) opcheck is broken?
# Somehow the cloned args are getting mutated in-place,
# which causes the opcheck to fail.
# https://github.com/vllm-project/vllm/issues/36688
return
assert hidden_size % group_size[1] == 0
num_groups = hidden_size // group_size[1]
scales = torch.empty(
(num_groups, num_tokens),
device=x.device,
dtype=torch.float32,
).transpose(0, 1)
opcheck(
torch.ops._C.rms_norm_per_block_quant,
(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import tempfile
from collections.abc import Callable
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch
import helion
from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import register_kernel
from vllm.kernels.helion.utils import get_canonical_gpu_name
# Canonical GPU name, resolved once at import time. dummy_kernel_registry uses
# it as the file stem of the per-kernel config JSON it writes.
GPU_PLATFORM = get_canonical_gpu_name()
# Configs written for a kernel when the caller of dummy_kernel_registry does
# not pass any.
DEFAULT_CONFIGS: dict[str, helion.Config] = {
"default": helion.Config(block_sizes=[32]),
}
@contextmanager
def dummy_kernel_registry(
    configs: dict[str, helion.Config] | None = None,
):
    """Yield a ``register`` decorator factory backed by throwaway configs.

    ``register`` takes the same arguments as ``register_kernel``. When the
    decorator it returns is applied to a function, it first writes
    ``<kernel name>/<GPU_PLATFORM>.json`` (kernel name is ``op_name`` or
    ``fn.__name__``) into a temporary directory served by a fresh
    ``ConfigManager``, and then delegates to the real ``register_kernel``.

    Args:
        configs: Mapping of config key to ``helion.Config``. ``None`` selects
            ``DEFAULT_CONFIGS``; an empty dict writes an empty JSON object.
    """
    selected = DEFAULT_CONFIGS if configs is None else configs
    config_data = {key: cfg.__dict__["config"] for key, cfg in selected.items()}

    with tempfile.TemporaryDirectory() as tmpdir:
        config_dir = Path(tmpdir)
        # Drop any existing singleton before creating the manager for tmpdir.
        ConfigManager.reset_instance()
        cm = ConfigManager(base_dir=config_dir)

        def register(
            op_name: str | None = None,
            **kwargs,
        ) -> Callable:
            def decorator(fn: Callable) -> Callable:
                kernel_name = op_name or fn.__name__
                target_dir = config_dir / kernel_name
                target_dir.mkdir(parents=True, exist_ok=True)
                config_file = target_dir / f"{GPU_PLATFORM}.json"
                config_file.write_text(json.dumps(config_data))
                # Pass the caller's op_name through unchanged (may be None).
                return register_kernel(op_name, **kwargs)(fn)

            return decorator

        # While patched, constructing ConfigManager via the config_manager
        # module hands back the temporary manager created above.
        with patch(
            "vllm.kernels.helion.config_manager.ConfigManager",
            return_value=cm,
        ):
            try:
                yield register
            finally:
                ConfigManager.reset_instance()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for autotuning Helion kernels, including disabled kernels with no configs."""
import pytest
import torch
from vllm.utils.import_utils import has_helion
# The imports that follow require helion, so skip the whole module without it.
if not has_helion():
pytest.skip(
"Helion is not installed. Install with: pip install vllm[helion]",
allow_module_level=True,
)
import helion
import helion.language as hl
from helion.autotuner.base_search import BaseSearch
from tests.kernels.helion.helpers import dummy_kernel_registry
from vllm.kernels.helion.register import create_helion_decorated_kernel
def _add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Element-wise ``x + y`` written as a tiled loop over ``x.size()``."""
    result = torch.empty_like(x)
    for idx in hl.tile(x.size()):
        result[idx] = x[idx] + y[idx]
    return result
class NoCompileSearch(BaseSearch):
    """Autotuner that returns the default config without GPU compilation.

    ``autotune`` performs no search; it hands back
    ``self.config_spec.default_config()`` directly.
    Modeled after helion's test BasicSearch (pytorch/helion#1649).
    """

    def autotune(self, *, skip_cache: bool = False):
        # skip_cache is accepted but not consulted by this implementation.
        spec = self.config_spec
        return spec.default_config()
def _no_compile_autotuner_fn(bound_kernel, args, **kwargs):
    """Build a ``NoCompileSearch`` for ``bound_kernel``, forwarding all arguments."""
    search = NoCompileSearch(bound_kernel, args, **kwargs)
    return search
class TestAutotuneDisabledKernel:
"""Test autotuning flow on disabled kernels (no platform configs)."""
def setup_method(self):
from vllm.kernels.helion.register import _REGISTERED_KERNELS
# Snapshot the global kernel registry, then empty it for the test.
self._saved_registry = dict(_REGISTERED_KERNELS)
_REGISTERED_KERNELS.clear()
def teardown_method(self):
from vllm.kernels.helion.register import _REGISTERED_KERNELS
# Put back exactly the entries captured in setup_method.
_REGISTERED_KERNELS.clear()
_REGISTERED_KERNELS.update(self._saved_registry)
def test_autotune_disabled_kernel_produces_valid_config(self):
"""Register a kernel with no configs (disabled), run autotune,
verify it produces a valid helion.Config."""
# configs={} writes an empty config file, so the wrapper is expected to
# come up disabled (asserted below).
with dummy_kernel_registry(configs={}) as register:
wrapper = register(
"autotune_test_kernel",
config_picker=lambda args, keys: "default",
fake_impl=lambda *a, **kw: None,
input_generator=lambda: {
"small": (
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device="cuda"),
),
},
)(_add_kernel)
assert wrapper._disabled is True
inputs = wrapper.get_inputs()
assert "small" in inputs
# Route autotuning through NoCompileSearch via the settings hook.
settings = helion.Settings()
settings.autotuner_fn = _no_compile_autotuner_fn
wrapper.helion_settings = settings
config = wrapper.run_autotune(inputs["small"])
# Reference value: the default config of the same kernel function bound to
# the same inputs with the same settings.
expected_default = (
create_helion_decorated_kernel(_add_kernel, helion_settings=settings)
.bind(inputs["small"])
.config_spec.default_config()
)
assert config == expected_default
......@@ -52,7 +52,7 @@ def _helion_mock_context():
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
......@@ -87,8 +87,8 @@ class TestMakeFxHop:
raw_kernel_func=raw_add_scale,
op_name="test_make_fx",
fake_impl=lambda *a, **kw: None,
config_picker=lambda args, keys: "default",
)
wrapper.register_config_picker(lambda args, keys: "default")
def fn(x, y):
return wrapper(x, y, scale)
......@@ -143,8 +143,8 @@ class TestMakeFxHop:
raw_kernel_func=raw_silu_mul,
op_name="test_pm_silu_mul",
fake_impl=lambda *a, **kw: None,
config_picker=lambda args, keys: "default",
)
wrapper.register_config_picker(lambda args, keys: "default")
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.silu(x) * y
......
......@@ -21,7 +21,9 @@ if not has_helion():
)
import helion
import helion.language as hl
from tests.kernels.helion.helpers import dummy_kernel_registry
from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import (
_HOP_AVAILABLE,
......@@ -34,6 +36,13 @@ from vllm.kernels.helion.register import (
)
def _add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
for tile in hl.tile(x.size()):
out[tile] = x[tile] + y[tile]
return out
@pytest.fixture
def sample_configs():
"""Create real Helion config objects for testing."""
......@@ -90,7 +99,7 @@ def configured_kernel(sample_kernel, sample_configs, config_manager_with_test_co
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=config_manager_with_test_configs,
),
patch(
......@@ -158,7 +167,7 @@ def create_configured_kernel_with_configs(
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
......@@ -189,7 +198,7 @@ class TestConfiguredHelionKernel:
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
......@@ -266,7 +275,7 @@ class TestConfiguredHelionKernel:
with (
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
......@@ -310,7 +319,7 @@ class TestConfiguredHelionKernel:
with (
patch("vllm.kernels.helion.register.helion.kernel") as mock_helion_kernel,
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
......@@ -346,23 +355,15 @@ class TestConfiguredHelionKernel:
class TestHelionKernelWrapper:
"""Test suite for HelionKernelWrapper."""
def test_get_configured_op_validates_configs_available(self, sample_kernel):
"""Test get_configured_op validates configs are available."""
def test_init_disables_on_missing_configs(self, sample_kernel):
"""Test __init__ marks wrapper as disabled when configs are missing."""
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
)
def default_picker(args, config_keys):
return "default"
wrapper._config_picker = default_picker
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(
return_value={}
......@@ -370,52 +371,99 @@ class TestHelionKernelWrapper:
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
pytest.raises(ValueError, match="No configs available"),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
wrapper.get_configured_op()
mock_kernel.return_value = Mock(return_value=sample_kernel)
def test_get_configured_op_validates_config_picker(
self, sample_kernel, sample_configs
):
"""Test get_configured_op validates config picker."""
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
assert wrapper._disabled is True
assert "No configs available" in wrapper._disabled_reason
def test_disabled_wrapper_raises_on_call(self, sample_kernel):
"""Test __call__ raises RuntimeError on a disabled wrapper."""
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
)
# Don't set config picker - should raise assertion error
def default_picker(args, config_keys):
return "default"
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
mock_config_manager.get_platform_configs = Mock(return_value={})
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
mock_kernel.return_value = Mock(return_value=sample_kernel)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
with pytest.raises(RuntimeError, match="is disabled"):
wrapper(torch.randn(4, 4), torch.randn(4, 4))
def test_disabled_wrapper_get_configured_op_raises(self, sample_kernel):
"""Test get_configured_op raises RuntimeError on a disabled wrapper."""
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
def default_picker(args, config_keys):
return "default"
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value={})
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
pytest.raises(AssertionError, match="No config picker registered"),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
mock_kernel.return_value = Mock(return_value=sample_kernel)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
with pytest.raises(RuntimeError, match="is disabled"):
wrapper.get_configured_op()
def test_get_configured_op_returns_cached_kernel(
self, sample_kernel, sample_configs
):
"""Test get_configured_op returns cached ConfiguredHelionKernel."""
def test_disabled_wrapper_supports_get_inputs(self, sample_kernel):
"""Test get_inputs works on a disabled wrapper."""
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
......@@ -423,19 +471,99 @@ class TestHelionKernelWrapper:
def default_picker(args, config_keys):
return "default"
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
)
wrapper._config_picker = default_picker
expected_inputs = {"key1": (torch.randn(4),)}
input_gen = Mock(return_value=expected_inputs)
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value={})
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
mock_kernel.return_value = Mock(return_value=sample_kernel)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
input_generator=input_gen,
)
assert wrapper._disabled is True
result = wrapper.get_inputs()
assert result is expected_inputs
def test_disabled_wrapper_supports_run_autotune(self, sample_kernel):
"""Test run_autotune works on a disabled wrapper."""
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
def default_picker(args, config_keys):
return "default"
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value={})
mock_config = Mock()
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
mock_kernel.return_value = Mock(return_value=sample_kernel)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
assert wrapper._disabled is True
with patch(
"vllm.kernels.helion.register.create_helion_decorated_kernel"
) as mock_create:
mock_autotune_kernel = Mock()
mock_autotune_kernel.autotune.return_value = mock_config
mock_create.return_value = mock_autotune_kernel
inputs = (torch.randn(4, 4),)
result = wrapper.run_autotune(inputs)
assert result is mock_config
def test_init_caches_configured_kernel(self, sample_kernel, sample_configs):
"""Test __init__ eagerly builds and caches ConfiguredHelionKernel."""
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
def default_picker(args, config_keys):
return "default"
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
......@@ -444,13 +572,77 @@ class TestHelionKernelWrapper:
),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
mock_decorated = Mock()
mock_kernel.return_value = Mock(return_value=mock_decorated)
mock_kernel.return_value = Mock(return_value=sample_kernel)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
assert wrapper._configured_kernel is not None
result1 = wrapper.get_configured_op()
result2 = wrapper.get_configured_op()
assert result1 is result2
@pytest.mark.skipif(
not _HOP_AVAILABLE, reason="HOP path only used when HOP available"
)
def test_init_eagerly_initializes_hop_path(self):
"""Test that register_kernel eagerly builds the configured kernel
on the HOP path (no custom op registration needed)."""
from vllm.kernels.helion.utils import get_canonical_gpu_name
configs = {"default": helion.Config(block_sizes=[4, 4])}
with (
dummy_kernel_registry(configs=configs) as register,
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
wraps=get_canonical_gpu_name,
) as mock_gpu,
):
wrapper = register(
config_picker=lambda args, keys: "default",
)(_add_kernel)
mock_gpu.assert_called_once()
assert wrapper._configured_kernel is not None
with patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
side_effect=AssertionError("get_canonical_gpu_name called during __call__"),
):
x = torch.randn(4, 4, device="cuda")
y = torch.randn(4, 4, device="cuda")
result = wrapper(x, y)
expected = x + y
assert torch.allclose(result, expected)
# Counterpart of the HOP-path test: skipped whenever the HOP path is available.
@pytest.mark.skipif(
_HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
)
def test_init_eagerly_initializes(self):
"""Test that register_kernel eagerly loads configs and detects GPU
during construction so __call__ needs no further initialization."""
from vllm.kernels.helion.utils import get_canonical_gpu_name
# Spy on GPU-name detection: wraps= forwards to the real function while
# recording every call made through the patched name.
with (
dummy_kernel_registry() as register,
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
wraps=get_canonical_gpu_name,
) as mock_gpu,
):
wrapper = register(
config_picker=lambda args, keys: "default",
)(_add_kernel)
# Registration itself must have detected the GPU (exactly one lookup) and
# built the configured kernel.
mock_gpu.assert_called_once()
assert wrapper._configured_kernel is not None
# The op must also be reachable by name under the torch.ops.vllm_helion namespace.
assert hasattr(torch.ops.vllm_helion, wrapper.op_name)
@pytest.mark.skipif(
_HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
)
......@@ -463,13 +655,6 @@ class TestHelionKernelWrapper:
def default_picker(args, config_keys):
return "default"
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
)
wrapper._config_picker = default_picker
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
......@@ -479,7 +664,7 @@ class TestHelionKernelWrapper:
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
......@@ -491,6 +676,13 @@ class TestHelionKernelWrapper:
):
mock_decorated = Mock()
mock_kernel.return_value = Mock(return_value=mock_decorated)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
result = wrapper._get_or_register_custom_op()
assert result is existing_op
......@@ -506,13 +698,6 @@ class TestHelionKernelWrapper:
def default_picker(args, config_keys):
return "default"
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
)
wrapper._config_picker = default_picker
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
......@@ -532,7 +717,7 @@ class TestHelionKernelWrapper:
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
......@@ -548,6 +733,13 @@ class TestHelionKernelWrapper:
):
mock_decorated = Mock()
mock_kernel.return_value = Mock(return_value=mock_decorated)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
result = wrapper._get_or_register_custom_op()
mock_register.assert_called_once()
......@@ -584,11 +776,10 @@ class TestKernelRegistry:
def test_get_kernel_by_name_returns_kernel(self):
"""Test get_kernel_by_name returns registered kernel."""
wrapper = HelionKernelWrapper(
raw_kernel_func=Mock(),
op_name="test_kernel",
fake_impl=Mock(),
)
with dummy_kernel_registry() as register:
wrapper = register(
"test_kernel", config_picker=lambda args, keys: "default"
)(_add_kernel)
from vllm.kernels.helion.register import _REGISTERED_KERNELS
......@@ -604,112 +795,87 @@ class TestKernelRegistry:
def test_register_kernel_auto_generates_fake_impl(self):
"""Test register_kernel auto-generates fake_impl when not provided."""
with patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer:
with (
dummy_kernel_registry() as register,
patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer,
):
mock_fake = Mock()
mock_infer.return_value = mock_fake
wrapper = register(
config_picker=lambda args, keys: "default",
)(_add_kernel)
def original_kernel(x):
return x
wrapper = register_kernel(original_kernel)
mock_infer.assert_called_once_with(original_kernel, None)
assert wrapper._fake_impl is mock_fake
mock_infer.assert_called_once_with(_add_kernel, None)
assert wrapper._fake_impl is mock_fake
def test_register_kernel_creates_wrapper(self):
"""Test register_kernel creates HelionKernelWrapper."""
def test_kernel(x):
return x
result = register_kernel("test_name")(test_kernel)
with dummy_kernel_registry() as register:
result = register("test_name", config_picker=lambda args, keys: "default")(
_add_kernel
)
assert isinstance(result, HelionKernelWrapper)
assert result.op_name == "test_name"
assert result.raw_kernel_func is test_kernel
assert result.raw_kernel_func is _add_kernel
def test_register_kernel_auto_detects_name(self):
"""Test register_kernel uses function name when no name provided."""
with dummy_kernel_registry() as register:
wrapper = register(config_picker=lambda args, keys: "default")(_add_kernel)
@register_kernel
def my_test_kernel(x):
return x
assert my_test_kernel.op_name == "my_test_kernel"
assert wrapper.op_name == "_add_kernel"
def test_register_kernel_registers_in_global_registry(self):
"""Test register_kernel adds wrapper to global registry."""
@register_kernel
def test_kernel(x):
return x
with dummy_kernel_registry() as register:
wrapper = register(
"test_kernel", config_picker=lambda args, keys: "default"
)(_add_kernel)
registered_kernels = get_registered_kernels()
assert "test_kernel" in registered_kernels
assert registered_kernels["test_kernel"] is test_kernel
assert registered_kernels["test_kernel"] is wrapper
def test_register_kernel_passes_helion_settings(self):
"""Test register_kernel passes helion_settings to wrapper."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"debug": True}
settings = helion.Settings()
settings.print_output_code = True
@register_kernel("test_name", helion_settings=mock_settings)
def test_kernel(x):
return x
with dummy_kernel_registry() as register:
result = register(
"test_name",
config_picker=lambda args, keys: "default",
helion_settings=settings,
)(_add_kernel)
assert test_kernel.helion_settings is mock_settings
assert result.helion_settings is settings
def test_register_kernel_supports_decorator_syntax(self):
"""Test register_kernel works with decorator arguments."""
mock_fake = Mock()
wrapper = register_kernel("custom_name", fake_impl=mock_fake)
def test_kernel(x):
return x
result = wrapper(test_kernel)
with dummy_kernel_registry() as register:
result = register(
"custom_name",
config_picker=lambda args, keys: "default",
fake_impl=mock_fake,
)(_add_kernel)
assert result.op_name == "custom_name"
assert result._fake_impl is mock_fake
def test_register_kernel_bare_decorator(self):
"""Test register_kernel works as bare decorator."""
@register_kernel
def test_kernel(x):
return x
assert isinstance(test_kernel, HelionKernelWrapper)
assert test_kernel.op_name == "test_kernel"
def test_registered_wrapper_can_register_config_picker(self):
"""Test that registered wrapper can register config picker."""
@register_kernel
def test_kernel(x):
return x
def my_picker(args, config_keys):
return "default"
result = test_kernel.register_config_picker(my_picker)
assert result is my_picker
assert test_kernel._config_picker is my_picker
def test_register_kernel_raises_on_duplicate_registration(self):
"""Test register_kernel raises error on duplicate names."""
with dummy_kernel_registry() as register:
register("duplicate_name", config_picker=lambda args, keys: "default")(
_add_kernel
)
@register_kernel("duplicate_name")
def kernel1(x):
return x
with pytest.raises(ValueError, match="already registered"):
@register_kernel("duplicate_name")
def kernel2(x):
return x
with pytest.raises(ValueError, match="already registered"):
register("duplicate_name", config_picker=lambda args, keys: "default")(
_add_kernel
)
def test_register_kernel_rejects_autotuner_fn_in_settings(self):
"""Test register_kernel rejects conflicting autotuner_fn."""
......@@ -718,7 +884,11 @@ class TestKernelRegistry:
with pytest.raises(ValueError, match="uses a custom autotuner"):
@register_kernel("test", helion_settings=mock_settings)
@register_kernel(
"test",
config_picker=lambda args, keys: "default",
helion_settings=mock_settings,
)
def test_kernel(x):
return x
......@@ -727,11 +897,47 @@ class TestKernelRegistry:
mock_settings = Mock()
mock_settings.to_dict.return_value = {"static_shapes": False}
with patch("vllm.kernels.helion.register.logger") as mock_logger:
with (
dummy_kernel_registry() as register,
patch("vllm.kernels.helion.register.logger") as mock_logger,
):
register(
"test",
config_picker=lambda args, keys: "default",
helion_settings=mock_settings,
)(_add_kernel)
@register_kernel("test", helion_settings=mock_settings)
def test_kernel(x):
return x
mock_logger.warning.assert_not_called()
# Should not call warning
mock_logger.warning.assert_not_called()
def test_disabled_kernel_appears_in_registry(self):
"""Test that a disabled wrapper is still in the global registry."""
# Minimal fake implementation: returns zeros shaped like the first argument.
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
# Stub config manager that reports no platform configs at all (empty dict).
# NOTE(review): presumably the empty config set is what makes the wrapper
# disabled -- confirm against register_kernel / HelionKernelWrapper.
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value={})
# Swap in the stub manager, pin the reported GPU name, and replace
# helion.kernel with a mock.
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
# Calling the mocked helion.kernel(...) yields a callable that hands back
# _add_kernel unchanged.
mock_kernel.return_value = Mock(return_value=_add_kernel)
wrapper = register_kernel(
"disabled_kernel",
config_picker=lambda args, keys: "default",
fake_impl=fake_impl,
)(_add_kernel)
# The wrapper is flagged as disabled ...
assert wrapper._disabled is True
# ... yet it is still registered under its name, as the very same object.
registered = get_registered_kernels()
assert "disabled_kernel" in registered
assert registered["disabled_kernel"] is wrapper
......@@ -22,7 +22,7 @@ INTERMEDIATE_DIM = [128, 2880]
BATCH_SIZE = [1, 64, 256]
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI]
USE_BIAS = [True, False]
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
ISA = ["amx", "vec"] if torch.cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]
......
......@@ -6,6 +6,7 @@ import pytest
import torch
import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_triton_kernels
if not has_triton_kernels():
......@@ -14,6 +15,7 @@ if not has_triton_kernels():
allow_module_level=True,
)
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
......@@ -21,12 +23,16 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close
from triton_kernels.topk import topk as topk_fn
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
legacy_routing,
make_routing_data,
triton_kernel_moe_forward,
)
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed
from .utils import shuffle_weight
......@@ -299,6 +305,12 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
if current_platform.is_device_capability_family(100):
constraints = {
"is_persistent": True,
}
opt_flags.update_opt_flags_constraints(constraints)
if a_dtype == "bf16" and w_dtype == "mx4":
quant_config = mxfp4_w4a16_moe_quant_config(
w1_scale=pc1,
......@@ -355,3 +367,43 @@ def test_unit_shuffle():
)
assert_close(ref=out_ref, tri=out)
@pytest.mark.parametrize("num_tokens", [2, 8, 64])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_legacy_routing(
    num_tokens: int, num_experts: int, topk: int, renormalize: bool, dtype: torch.dtype
):
    """Check ``legacy_routing`` against a reference routing.

    The reference is built by running ``topk_fn`` on the gating logits and
    feeding the selected ids/weights to ``make_routing_data``. The gather and
    scatter index tensors of both paths must match exactly.
    """
    set_random_seed(0)
    gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)

    # Without renormalization the softmax is taken over all experts before
    # top-k; with it, topk_fn is asked to apply the softmax itself.
    sm_first = not renormalize
    ref_logits = torch.softmax(gating_output, dim=-1) if sm_first else gating_output
    selected = topk_fn(ref_logits, topk, apply_softmax=not sm_first)

    _ref_routing, ref_gather, ref_scatter = make_routing_data(
        selected.indx.to(torch.long), selected.vals, num_experts
    )
    _routing, gather, scatter = legacy_routing(gating_output, topk, sm_first=sm_first)

    # Zero tolerances: the index tensors are compared for exact equality.
    for reference, candidate in ((ref_gather, gather), (ref_scatter, scatter)):
        for field in ("src_indx", "dst_indx"):
            assert_close(
                ref=getattr(reference, field),
                tri=getattr(candidate, field),
                maxtol=0,
                rmstol=0,
            )
......@@ -82,7 +82,7 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
model_case.model_id,
tensor_parallel_size=model_case.tp,
load_format="dummy",
cudagraph_capture_sizes=[16],
compilation_config={"cudagraph_capture_sizes": [16]},
) as llm:
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
# def check_model(model):
......
......@@ -10,7 +10,6 @@
# and the platform is not ROCm.
import importlib.util
import os
import pytest
import torch
......@@ -20,9 +19,6 @@ from vllm.platforms import current_platform
if not current_platform.is_rocm():
pytest.skip("This test can only run on ROCm.", allow_module_level=True)
# This environment variable must be set so ops will be registered.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
# this import statement is needed to ensure the ops are registered
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe # noqa: F401
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for optimized router GEMM kernel
Run `pytest tests/kernels/moe/test_router_gemm.py`.
"""
import pytest
import torch
import vllm._custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability_family(100)
    ),
    reason="This test only runs on Hopper or Blackwell GPUs.",
)
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8])
@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880])
@pytest.mark.parametrize("output_dim", [32, 64, 128])
def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim):
    """Compare ``ops.gpt_oss_router_gemm`` with ``torch.nn.functional.linear``
    on random bfloat16 inputs, within 1e-2 absolute/relative tolerance."""
    set_random_seed(0)
    tensor_kwargs = {"device": "cuda", "dtype": torch.bfloat16}

    # Sampling order (input, weight, bias) is kept so the seeded values match.
    hidden = torch.randn(batch_size, input_dim, **tensor_kwargs)
    router_weight = torch.randn(output_dim, input_dim, **tensor_kwargs)
    router_bias = torch.randn(output_dim, **tensor_kwargs)

    actual = ops.gpt_oss_router_gemm(hidden, router_weight, router_bias)
    reference = torch.nn.functional.linear(hidden, router_weight, router_bias)
    torch.testing.assert_close(actual, reference, atol=1e-2, rtol=1e-2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment