Commit 7e63ef82 authored by zhuwenwen
Browse files

Merge tag 'v0.14.0' into v0.14.0-dev

parents 8cbcac5d b17039bc
...@@ -5,10 +5,12 @@ import functools ...@@ -5,10 +5,12 @@ import functools
import multiprocessing import multiprocessing
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path
import pytest import pytest
import torch import torch
import vllm.model_executor.layers.activation
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
...@@ -16,9 +18,19 @@ from vllm.config import ( ...@@ -16,9 +18,19 @@ from vllm.config import (
VllmConfig, VllmConfig,
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.envs import disable_envs_cache
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test
@pytest.fixture
def vllm_tmp_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Point ``VLLM_CACHE_ROOT`` at a directory under pytest's ``tmp_path``.

    The environment variable is set to ``<tmp_path>/vllm_cache``. Note that
    the value handed to the test is ``tmp_path`` itself (the parent
    directory), not the cache directory.
    """
    cache_root = tmp_path / "vllm_cache"
    monkeypatch.setenv("VLLM_CACHE_ROOT", str(cache_root))
    return tmp_path
def reference_fn(x: torch.Tensor): def reference_fn(x: torch.Tensor):
assert x.shape[0] <= 42 assert x.shape[0] <= 42
...@@ -66,6 +78,7 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch): ...@@ -66,6 +78,7 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
torch.compiler.set_stance("fail_on_recompile"), torch.compiler.set_stance("fail_on_recompile"),
): ):
CompiledMod(vllm_config=vllm_config)(*args) CompiledMod(vllm_config=vllm_config)(*args)
disable_envs_cache()
m.setenv("VLLM_USE_AOT_COMPILE", "1") m.setenv("VLLM_USE_AOT_COMPILE", "1")
torch._dynamo.reset() torch._dynamo.reset()
...@@ -101,6 +114,7 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch): ...@@ -101,6 +114,7 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
with use_vllm_config(vllm_config): with use_vllm_config(vllm_config):
expected = CompiledMod(vllm_config=vllm_config)(*args) expected = CompiledMod(vllm_config=vllm_config)(*args)
disable_envs_cache()
m.setenv("VLLM_FORCE_AOT_LOAD", "1") m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
...@@ -130,6 +144,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): ...@@ -130,6 +144,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
artifacts = compiled_mod.aot_compiled_fn._artifacts artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards() guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
disable_envs_cache()
m.setenv("VLLM_FORCE_AOT_LOAD", "1") m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
...@@ -144,7 +159,94 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): ...@@ -144,7 +159,94 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_partition_wrapper_applied_on_aot_load(
    monkeypatch: pytest.MonkeyPatch, vllm_tmp_cache: Path, mocker
):
    """
    Check that partition wrappers are set and cleared both when an AOT-cached
    function is loaded and when it is called again afterwards.

    Regression test for GitHub issue #31439: with
    use_inductor_graph_partition=True, AOT compile caused a 2x latency
    regression because the partition wrapper context was bypassed when
    loading from the AOT cache.
    """
    from vllm.config import CUDAGraphMode

    def build_partition_config() -> VllmConfig:
        # Each phase gets its own config object with graph partition enabled.
        return VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                use_inductor_graph_partition=True,
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
            )
        )

    inputs = (torch.randn(10, 10),)
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")

    # Phase 1: compile once so the artifact is saved to the cache.
    save_config = build_partition_config()
    with use_vllm_config(save_config):
        saving_mod = CompiledMod(vllm_config=save_config)
        saving_mod(*inputs)
    disable_envs_cache()

    # Phase 2: force a load from the cache and watch the wrapper hook.
    monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
    load_config = build_partition_config()
    wrapper_spy = mocker.spy(
        torch._inductor.utils, "set_customized_partition_wrappers"
    )

    with use_vllm_config(load_config):
        loading_mod = CompiledMod(vllm_config=load_config)

        # First call after the "restart": goes through the AOT cache load.
        loading_mod(*inputs)

        assert wrapper_spy.call_count >= 2, (
            "Expected partition wrapper to be set and cleared on AOT load, "
            f"got {wrapper_spy.call_count} calls"
        )
        recorded = wrapper_spy.call_args_list
        # The earliest recorded call installs a wrapper, the latest removes it.
        assert recorded[0].args[0] is not None, (
            "First call on AOT load should set a wrapper function"
        )
        assert recorded[-1].args[0] is None, (
            "Last call on AOT load should clear the wrapper"
        )

        # Start counting from zero for the second invocation.
        wrapper_spy.reset_mock()

        # Second call: reuses the already-loaded `aot_compiled_fn`.
        loading_mod(*inputs)

        assert wrapper_spy.call_count >= 2, (
            "Expected partition wrapper set and cleared on subsequent "
            f"call, got {wrapper_spy.call_count} calls"
        )
        # reset_mock() replaced the call list, so fetch it again.
        recorded = wrapper_spy.call_args_list
        assert recorded[0].args[0] is not None, (
            "First call on subsequent call should set a wrapper function"
        )
        assert recorded[-1].args[0] is None, (
            "Last call on subsequent call should clear the wrapper"
        )
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
@create_new_process_for_each_test("spawn")
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch): def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
""" """
Test that compiling gpt2 twice results in a cache hit and Test that compiling gpt2 twice results in a cache hit and
...@@ -186,6 +288,8 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch): ...@@ -186,6 +288,8 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
# Clean up first model # Clean up first model
del llm_model del llm_model
disable_envs_cache()
vllm.model_executor.layers.activation._ACTIVATION_REGISTRY._dict.clear()
# Second compilation - should hit cache # Second compilation - should hit cache
m.setenv("VLLM_FORCE_AOT_LOAD", "1") m.setenv("VLLM_FORCE_AOT_LOAD", "1")
......
...@@ -15,7 +15,10 @@ from vllm.config.compilation import CompilationMode, PassConfig ...@@ -15,7 +15,10 @@ from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.logger import _print_warning_once from vllm.logger import _print_warning_once
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer from vllm.utils.torch_utils import (
_is_torch_equal_or_newer,
is_torch_equal,
)
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401 from . import silly_attention # noqa: F401
...@@ -30,6 +33,29 @@ def test_version(): ...@@ -30,6 +33,29 @@ def test_version():
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev") assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
def test_get_raw_stream_patch():
    """On torch 2.9.0 / 2.9.1, check that ``builtins.get_raw_stream`` is patched.

    For those versions the attribute must exist, be callable, and be the
    ``_cuda_getCurrentRawStream`` function from ``torch._C``. Nothing is
    asserted for any other torch version.
    """
    import builtins

    # Record the patch state up front, before any version check.
    patched = hasattr(builtins, "get_raw_stream")

    on_torch_2_9 = any(is_torch_equal(version) for version in ("2.9.0", "2.9.1"))
    if not on_torch_2_9:
        return

    assert patched, "get_raw_stream should be patched for torch 2.9.x"

    raw_stream_fn = builtins.get_raw_stream  # type: ignore[attr-defined]
    assert callable(raw_stream_fn)

    # The patch is expected to be the exact function object from torch._C.
    from torch._C import _cuda_getCurrentRawStream

    assert raw_stream_fn is _cuda_getCurrentRawStream
def test_copy_pass(): def test_copy_pass():
vllm_config = VllmConfig() vllm_config = VllmConfig()
inductor_pass = FixFunctionalizationPass(vllm_config) inductor_pass = FixFunctionalizationPass(vllm_config)
...@@ -406,51 +432,43 @@ def test_cudagraph_sizes_post_init( ...@@ -406,51 +432,43 @@ def test_cudagraph_sizes_post_init(
) )
def test_pass_config_deprecation(caplog_vllm): def test_cached_compilation_config(default_vllm_config):
caplog_vllm.set_level(logging.WARNING) import torch
from torch._inductor.utils import run_and_get_code
# Clear cache to ensure warnings are re-issued
_print_warning_once.cache_clear() from vllm.config import get_cached_compilation_config, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
# Test enable_fusion -> fuse_norm_quant, fuse_act_quant from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
caplog_vllm.clear()
config = PassConfig(enable_fusion=True) dtype = torch.bfloat16
assert "enable_fusion is deprecated" in caplog_vllm.text device = torch.device("cuda:0")
assert config.fuse_norm_quant is True batch_size, num_qo_heads, head_size = 8, 16, 128
assert config.fuse_act_quant is True
assert config.enable_fusion is None # access and cache default compilation config
# default compilation config does not contain +quant_fp8 custom op. If this is
# Test enable_attn_fusion -> fuse_attn_quant # used, the generated code would use inductor-generated triton kernel instead
caplog_vllm.clear() # of the custom op `torch.ops._C.static_scaled_fp8_quant`.
config = PassConfig(enable_attn_fusion=True) get_cached_compilation_config()
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
assert config.fuse_attn_quant is True vllm_config = VllmConfig(
assert config.enable_attn_fusion is None compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
# Test enable_noop -> eliminate_noops custom_ops=["+quant_fp8"],
caplog_vllm.clear() )
config = PassConfig(enable_noop=True) )
assert "enable_noop is deprecated" in caplog_vllm.text
assert config.eliminate_noops is True # set_current_vllm_config should clear cached compilation config and
assert config.enable_noop is None # use the new compilation_config in vllm_config
with set_current_vllm_config(vllm_config):
# Test enable_sequence_parallelism -> enable_sp query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
caplog_vllm.clear() query_quant = torch.compile(query_quant)
config = PassConfig(enable_sequence_parallelism=True)
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
assert config.enable_sp is True query = torch.randn(
assert config.enable_sequence_parallelism is None batch_size, num_qo_heads * head_size, dtype=dtype, device=device
)
# Test enable_async_tp -> fuse_gemm_comms
caplog_vllm.clear() _, code = run_and_get_code(query_quant, query, _q_scale)
config = PassConfig(enable_async_tp=True)
assert "enable_async_tp is deprecated" in caplog_vllm.text code = " ".join(code)
assert config.fuse_gemm_comms is True assert "torch.ops._C.static_scaled_fp8_quant.default(" in code
assert config.enable_async_tp is None
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
caplog_vllm.clear()
config = PassConfig(enable_fi_allreduce_fusion=True)
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
assert config.fuse_allreduce_rms is True
assert config.enable_fi_allreduce_fusion is None
...@@ -77,6 +77,7 @@ def test_dynamic_shapes_compilation( ...@@ -77,6 +77,7 @@ def test_dynamic_shapes_compilation(
"evaluate_guards": evaluate_guards, "evaluate_guards": evaluate_guards,
}, },
}, },
max_model_len=1024,
) )
output = model.generate(prompt) output = model.generate(prompt)
......
...@@ -25,10 +25,13 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size): ...@@ -25,10 +25,13 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.pos_embed = torch.empty(buffer_size, hidden_size, dtype=dtype) # Avoid using empty, since on rocm torch.empty
# does not initialize the memory.
self.pos_embed = torch.randn(buffer_size, hidden_size, dtype=dtype)
def forward(self, x): def forward(self, x):
x += self.pos_embed[: x.shape[0]] # Avoid += to prevent inplace addition.
x = x + self.pos_embed[: x.shape[0]]
# Chain of reshapes # Chain of reshapes
y = x.reshape(-1, 128, 32) y = x.reshape(-1, 128, 32)
z = y.reshape(-1, 4096) z = y.reshape(-1, 4096)
......
...@@ -5,7 +5,6 @@ import pytest ...@@ -5,7 +5,6 @@ import pytest
import torch import torch
from tests.compile.backend import TestBackend from tests.compile.backend import TestBackend
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
...@@ -25,6 +24,7 @@ from vllm.config import ( ...@@ -25,6 +24,7 @@ from vllm.config import (
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionType
RSQRT_OP = torch.ops.aten.rsqrt.default RSQRT_OP = torch.ops.aten.rsqrt.default
INDEX_SELECT_OP = torch.ops.aten.index.Tensor INDEX_SELECT_OP = torch.ops.aten.index.Tensor
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest import pytest
import torch import torch
...@@ -53,37 +52,61 @@ class TestModel(torch.nn.Module): ...@@ -53,37 +52,61 @@ class TestModel(torch.nn.Module):
hidden_size: int, hidden_size: int,
eps: float, eps: float,
group_shape: GroupShape, group_shape: GroupShape,
cuda_force_torch: bool, use_aiter: bool = False,
cuda_force_torch: bool = False,
use_aiter_quant_op: bool = True,
*args, *args,
**kwargs, **kwargs,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.use_aiter = use_aiter
self.use_aiter_quant_op = use_aiter_quant_op
self.cuda_force_torch = cuda_force_torch self.cuda_force_torch = cuda_force_torch
self.group_shape = group_shape
self.enable_quant_fp8_custom_op = None # Will be set later if applicable
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
if group_shape.is_per_group():
self.wscale = [ # Setup quantization scale descriptor
torch.rand( static = group_shape == GroupShape.PER_TENSOR and not use_aiter
(hidden_size // group_shape[1], hidden_size // group_shape[1]),
dtype=torch.float32,
)
for _ in range(3)
]
else:
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
static = group_shape == GroupShape.PER_TENSOR
quant_scale = ScaleDesc(torch.float32, static, group_shape) quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
# Setup scales
if static: if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else: else:
self.scale = [None for _ in range(3)] self.scale = [None for _ in range(3)]
# Setup weights
self.w = [ self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3) torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
] ]
if not group_shape.is_per_group(): if not group_shape.is_per_group() or use_aiter:
self.w = [self.w[0].t() for _ in range(3)] self.w = [self.w[0].t() for _ in range(3)]
# Setup weight scales
if group_shape.is_per_group(): if group_shape.is_per_group():
scale_size = (
(hidden_size + 128 - 1) // 128
if use_aiter
else hidden_size // group_shape[1]
)
wscale_shape: tuple[int, ...] = (scale_size, scale_size)
else:
wscale_shape = (1,)
self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]
# Setup FP8 linear operation
is_per_group = group_shape.is_per_group()
if is_per_group and use_aiter:
self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=group_shape,
use_aiter_and_is_supported=use_aiter_quant_op,
)
# AITER blockwise doesn't use enable_quant_fp8_custom_op
elif is_per_group:
self.fp8_linear = W8A8BlockFp8LinearOp( self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(group_shape[1], group_shape[1]), weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape, act_quant_group_shape=group_shape,
...@@ -91,6 +114,13 @@ class TestModel(torch.nn.Module): ...@@ -91,6 +114,13 @@ class TestModel(torch.nn.Module):
use_aiter_and_is_supported=False, use_aiter_and_is_supported=False,
) )
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
elif use_aiter:
self.fp8_linear = Fp8LinearOp(
act_quant_static=False,
act_quant_group_shape=group_shape,
)
self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
else: else:
with override_cutlass_fp8_supported(not cuda_force_torch): with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp( self.fp8_linear = Fp8LinearOp(
...@@ -100,7 +130,6 @@ class TestModel(torch.nn.Module): ...@@ -100,7 +130,6 @@ class TestModel(torch.nn.Module):
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
self.enable_rms_norm_custom_op = self.norm[0].enabled() self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.group_shape = group_shape
def forward(self, x): def forward(self, x):
# avoid having graph input be an arg to a pattern directly # avoid having graph input be an arg to a pattern directly
...@@ -126,19 +155,49 @@ class TestModel(torch.nn.Module): ...@@ -126,19 +155,49 @@ class TestModel(torch.nn.Module):
y4, resid = self.norm[3](x4, resid) # use resid here y4, resid = self.norm[3](x4, resid) # use resid here
return y4 return y4
def ops_in_model_before(self):
    """Return the single-element list of the quant op to check before fusion.

    The op is chosen from ``use_aiter``, the group shape, the platform's
    fp8 flavour, ``use_aiter_quant_op`` and ``enable_quant_fp8_custom_op``.
    """
    if self.use_aiter:
        if self.group_shape.is_per_group():
            # Per-group AITER: the op depends on whether the platform is fnuz.
            if current_platform.is_fp8_fnuz():
                return [rocm_aiter_ops.get_group_quant_op()]
            return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
        if self.use_aiter_quant_op:
            return [rocm_aiter_ops.get_per_token_quant_op()]
        return [QUANT_OPS[self.quant_key]]

    # Non-AITER path: custom quant op if enabled, otherwise aten.reciprocal.
    if self.enable_quant_fp8_custom_op:
        return [QUANT_OPS[self.quant_key]]
    return [torch.ops.aten.reciprocal]
def ops_in_model_after(self):
    """Return the two fused ops to check after fusion (with-add, then plain).

    Without AITER the ops come from ``FUSED_OPS`` keyed on ``quant_key``.
    With AITER they are the ``FUSED_OP`` attributes of the group-quant or
    dynamic-quant pattern classes, depending on the group shape.
    """
    if not self.use_aiter:
        return [
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
        ]

    # The AITER pattern classes are imported only on the branch that uses them.
    if self.group_shape.is_per_group():
        from vllm.compilation.rocm_aiter_fusion import (
            AiterFusedAddRMSFp8GroupQuantPattern,
            AiterRMSFp8GroupQuantPattern,
        )

        return [
            AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
            AiterRMSFp8GroupQuantPattern.FUSED_OP,
        ]

    from vllm.compilation.rocm_aiter_fusion import (
        AiterFusedAddRMSNormDynamicQuantPattern,
        AiterRMSNormDynamicQuantPattern,
    )

    return [
        AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
        AiterRMSNormDynamicQuantPattern.FUSED_OP,
    ]
def ops_in_model_before(self):
return (
[QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal]
)
def ops_in_model_before_partial(self): def ops_in_model_before_partial(self):
return ( return (
[RMS_OP, RMS_ADD_OP] [RMS_OP, RMS_ADD_OP]
...@@ -155,67 +214,45 @@ GROUP_SHAPES = [ ...@@ -155,67 +214,45 @@ GROUP_SHAPES = [
] ]
class TestRmsnormGroupFp8QuantModel(torch.nn.Module): def _run_fusion_test(
def __init__(self, hidden_size: int, eps: float, **kwargs): model,
super().__init__() fusion_pass,
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( vllm_config,
weight_group_shape=GroupShape(128, 128), dtype,
act_quant_group_shape=GroupShape(1, 128), hidden_size,
cutlass_block_fp8_supported=False, num_tokens,
use_aiter_and_is_supported=True, ):
) """Helper function for common fusion test logic.
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(3)
]
scale_hidden_size = (hidden_size + 128 - 1) // 128 Must be called within vllm_config context.
self.wscale = [ """
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32) noop_pass = NoOpEliminationPass(vllm_config)
for _ in range(3) cleanup_pass = PostCleanupPass(vllm_config)
]
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)] backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
self.eps = eps backend2 = TestBackend(noop_pass, cleanup_pass)
def forward(self, x): x = torch.rand(num_tokens, hidden_size)
# avoid having graph input be an arg to a pattern directly torch._dynamo.mark_dynamic(x, 0)
x = resid = torch.relu(x)
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0]) model_fused = torch.compile(model, backend=backend)
# make sure resid is used for replacement to work result_fused = model_fused(x)
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
x2, resid, self.norm_weight[1], self.eps
)
x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1]) model_unfused = torch.compile(model, backend=backend2)
result_unfused = model_unfused(x)
y3, resid = rocm_aiter_ops.rms_norm2d_with_add( if dtype == torch.float16:
x3, resid, self.norm_weight[2], self.eps ATOL, RTOL = (2e-3, 2e-3)
) else:
ATOL, RTOL = (1e-2, 1e-2)
x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2]) torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
y4, resid = rocm_aiter_ops.rms_norm2d_with_add( assert fusion_pass.matched_count == 3
x4, resid, self.norm_weight[3], self.eps backend.check_before_ops(model.ops_in_model_before())
) backend.check_after_ops(model.ops_in_model_after())
return y4
def ops_in_model_before(self): return backend, backend2
return [
torch.ops.vllm.rocm_aiter_rms_norm,
torch.ops.vllm.rocm_aiter_group_fp8_quant,
]
def ops_in_model_before_partial(self):
return []
def ops_in_model_after(self):
return [
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
...@@ -223,11 +260,8 @@ class TestRmsnormGroupFp8QuantModel(torch.nn.Module): ...@@ -223,11 +260,8 @@ class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES) @pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize( @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op", @pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
list(itertools.product([TestModel], [True, False], [True, False]))
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
)
# cuda_force_torch used to test torch code path on platforms that # cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True. # cutlass_fp8_supported() == True.
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -242,23 +276,13 @@ def test_fusion_rmsnorm_quant( ...@@ -242,23 +276,13 @@ def test_fusion_rmsnorm_quant(
num_tokens, num_tokens,
eps, eps,
group_shape, group_shape,
model_class,
enable_rms_norm_custom_op, enable_rms_norm_custom_op,
enable_quant_fp8_custom_op, enable_quant_fp8_custom_op,
cuda_force_torch, cuda_force_torch,
): ):
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
if not enable_quant_fp8_custom_op and group_shape.is_per_group(): if not enable_quant_fp8_custom_op and group_shape.is_per_group():
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization") pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
# Skip test for 64-bit group shape when running with cutlass or deepgemm
if group_shape == GroupShape(1, 64) and ( if group_shape == GroupShape(1, 64) and (
cutlass_block_fp8_supported() or is_deep_gemm_supported() cutlass_block_fp8_supported() or is_deep_gemm_supported()
): ):
...@@ -269,6 +293,7 @@ def test_fusion_rmsnorm_quant( ...@@ -269,6 +293,7 @@ def test_fusion_rmsnorm_quant(
custom_ops.append("+rms_norm") custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op: if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8") custom_ops.append("+quant_fp8")
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype), model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
...@@ -279,60 +304,97 @@ def test_fusion_rmsnorm_quant( ...@@ -279,60 +304,97 @@ def test_fusion_rmsnorm_quant(
), ),
), ),
) )
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
if model_class is TestRmsnormGroupFp8QuantModel:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
)
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config) with vllm.config.set_current_vllm_config(vllm_config):
else: # Setup device before model creation
fusion_pass = RMSNormQuantFusionPass(vllm_config) torch.set_default_device("cuda")
cleanup_pass = PostCleanupPass(vllm_config) torch.set_default_dtype(dtype)
torch.manual_seed(1)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) maybe_create_device_identity()
backend2 = TestBackend(noop_pass, cleanup_pass)
model = model_class( fusion_pass = RMSNormQuantFusionPass(vllm_config)
model = TestModel(
hidden_size=hidden_size, hidden_size=hidden_size,
eps=eps, eps=eps,
group_shape=group_shape, group_shape=group_shape,
use_aiter=False,
cuda_force_torch=cuda_force_torch, cuda_force_torch=cuda_force_torch,
) )
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
model_unfused = torch.compile(model, backend=backend2)
result_unfused = model_unfused(x)
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 3 backend, _ = _run_fusion_test(
backend.check_before_ops(model.ops_in_model_before()) model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
)
backend.check_before_ops( backend.check_before_ops(
model.ops_in_model_before_partial(), fully_replaced=False model.ops_in_model_before_partial(), fully_replaced=False
) )
backend.check_after_ops(model.ops_in_model_after())
# If RMSNorm custom op is disabled (native/torch impl used), # If RMSNorm custom op is disabled (native/torch impl used),
# there's a risk that the fused add doesn't get included in the # there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant. # replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add). # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if ( if not enable_rms_norm_custom_op:
not enable_rms_norm_custom_op
and model_class is not TestRmsnormGroupFp8QuantModel
):
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7 assert n_add_nodes(backend.graph_pre_pass) == 7
assert n_add_nodes(backend.graph_post_pass) == 2 assert n_add_nodes(backend.graph_post_pass) == 2
# (group_shape, use_aiter_quant_op) pairs used to parametrize
# test_aiter_fusion_rmsnorm_quant.
GROUP_SHAPE_QUANT_OPS_MATCHS = [
(GroupShape.PER_TOKEN, True),
(GroupShape.PER_TOKEN, False),
(GroupShape(1, 128), True),
]
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
    "group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS
)
@pytest.mark.skipif(
    not (current_platform.is_rocm() and IS_AITER_FOUND),
    reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    group_shape: GroupShape,
    use_aiter_quant_op: bool,
    monkeypatch: pytest.MonkeyPatch,
):
    """Run the shared fusion check on an AITER-enabled ``TestModel``.

    Sets ``VLLM_ROCM_USE_AITER=1``, refreshes the AITER env state, builds
    the model with ``use_aiter=True`` and hands it to ``_run_fusion_test``
    together with a ``RocmAiterRMSNormFusionPass``.
    """
    model_config = ModelConfig(dtype=dtype)
    pass_config = PassConfig(fuse_norm_quant=True, eliminate_noops=True)
    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=["+rms_norm", "+quant_fp8"],
        pass_config=pass_config,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        compilation_config=compilation_config,
    )

    with vllm.config.set_current_vllm_config(vllm_config):
        with monkeypatch.context() as m:
            from vllm.compilation.rocm_aiter_fusion import (
                RocmAiterRMSNormFusionPass,
            )

            # Turn AITER on, then have rocm_aiter_ops re-read the environment.
            m.setenv("VLLM_ROCM_USE_AITER", "1")
            rocm_aiter_ops.refresh_env_variables()

            # Device, dtype and seed are set before the model is constructed.
            torch.set_default_device("cuda")
            torch.set_default_dtype(dtype)
            torch.manual_seed(1)
            maybe_create_device_identity()

            aiter_pass = RocmAiterRMSNormFusionPass(vllm_config)
            aiter_model = TestModel(
                hidden_size=hidden_size,
                eps=eps,
                group_shape=group_shape,
                use_aiter=True,
                use_aiter_quant_op=use_aiter_quant_op,
            )

            _run_fusion_test(
                aiter_model, aiter_pass, vllm_config, dtype, hidden_size, num_tokens
            )
...@@ -9,8 +9,6 @@ from tests.compile.backend import LazyInitPass, TestBackend ...@@ -9,8 +9,6 @@ from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import flat_product from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
...@@ -37,6 +35,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -37,6 +35,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -305,8 +305,12 @@ def test_attention_quant_pattern( ...@@ -305,8 +305,12 @@ def test_attention_quant_pattern(
model_class: type[AttentionQuantPatternModel], model_class: type[AttentionQuantPatternModel],
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
dist_init, dist_init,
monkeypatch,
use_fresh_inductor_cache,
): ):
"""Test AttentionStaticQuantPattern fusion pass""" """Test AttentionStaticQuantPattern fusion pass"""
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if backend == AttentionBackendEnum.FLASHINFER and ( if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer() not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
): ):
...@@ -363,13 +367,15 @@ def test_attention_quant_pattern( ...@@ -363,13 +367,15 @@ def test_attention_quant_pattern(
vllm_config=vllm_config_unfused, vllm_config=vllm_config_unfused,
) )
model_unfused = model_unfused.to(device) model_unfused = model_unfused.to(device)
result_unfused_0 = model_unfused(q, k, v) # noqa: F841 HACK: See #31044
forward_ctx = get_forward_context() forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
# Run model directly without fusion # Run model directly without fusion
# Still compile so query QuantFP8 has closer numerics # Still compile so query QuantFP8 has closer numerics
result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v) compiled_unfused = torch.compile(model_unfused, fullgraph=True)
result_unfused = compiled_unfused(q, k, v)
# Run model with attn fusion enabled # Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig( vllm_config.compilation_config.pass_config = PassConfig(
...@@ -399,24 +405,26 @@ def test_attention_quant_pattern( ...@@ -399,24 +405,26 @@ def test_attention_quant_pattern(
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass) test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
# HACK: See https://github.com/vllm-project/vllm/issues/31044
result_fused_0 = model_fused(q, k, v) # noqa: F841
# Compile model with fusion enabled # Compile model with fusion enabled
model_compiled = torch.compile( compiled_fused = torch.compile(
model_fused, backend=test_backend, fullgraph=True model_fused, backend=test_backend, fullgraph=True
) )
assert model_compiled.attn._o_scale_float is None assert compiled_fused.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v) result_fused = compiled_fused(q, k, v)
if backend == AttentionBackendEnum.FLASHINFER: if backend == AttentionBackendEnum.FLASHINFER:
# With the Flashinfer backend after the 1st round of the forward # With the Flashinfer backend after the 1st round of the forward
# pass, output quant scale should be loaded into the attn layer's # pass, output quant scale should be loaded into the attn layer's
# _o_scale_float, the 2nd round should reuse the loaded # _o_scale_float, the 2nd round should reuse the loaded
# _o_scale_float # _o_scale_float
assert model_compiled.attn._o_scale_float is not None assert compiled_fused.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v) result_fused_2 = compiled_fused(q, k, v)
assert model_compiled.attn._o_scale_float is not None assert compiled_fused.attn._o_scale_float is not None
torch.testing.assert_close( torch.testing.assert_close(
result_unfused, result_fused_2, atol=1e-2, rtol=1e-2 result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
...@@ -474,4 +482,4 @@ def test_attention_quant_pattern( ...@@ -474,4 +482,4 @@ def test_attention_quant_pattern(
) )
# Check that results are close # Check that results are close
torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2) torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)
{
"state-spaces/mamba-130m-hf": {
"architectures": [
"MambaForCausalLM"
],
"model_type": "mamba",
"text_model_type": "mamba",
"hidden_size": 768,
"total_num_hidden_layers": 24,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 50280,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.float32"
},
"mistralai/Mamba-Codestral-7B-v0.1": {
"architectures": [
"Mamba2ForCausalLM"
],
"model_type": "mamba",
"text_model_type": "mamba",
"hidden_size": 4096,
"total_num_hidden_layers": 64,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 32768,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": {
"architectures": [
"Terratorch"
],
"model_type": "timm_wrapper",
"text_model_type": "timm_wrapper",
"hidden_size": 0,
"total_num_hidden_layers": 0,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 0,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": true,
"dtype": "torch.float32"
},
"tiiuae/falcon-mamba-7b-instruct": {
"architectures": [
"FalconMambaForCausalLM"
],
"model_type": "falcon_mamba",
"text_model_type": "falcon_mamba",
"hidden_size": 4096,
"total_num_hidden_layers": 64,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 65024,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"Zyphra/Zamba2-7B-instruct": {
"architectures": [
"Zamba2ForCausalLM"
],
"model_type": "zamba2",
"text_model_type": "zamba2",
"hidden_size": 3584,
"total_num_hidden_layers": 81,
"total_num_attention_heads": 32,
"head_size": 224,
"vocab_size": 32000,
"total_num_kv_heads": 32,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"mosaicml/mpt-7b": {
"architectures": [
"MPTForCausalLM"
],
"model_type": "mpt",
"text_model_type": "mpt",
"hidden_size": 4096,
"total_num_hidden_layers": 32,
"total_num_attention_heads": 32,
"head_size": 128,
"vocab_size": 50432,
"total_num_kv_heads": 32,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"databricks/dbrx-instruct": {
"architectures": [
"DbrxForCausalLM"
],
"model_type": "dbrx",
"text_model_type": "dbrx",
"hidden_size": 6144,
"total_num_hidden_layers": 40,
"total_num_attention_heads": 48,
"head_size": 128,
"vocab_size": 100352,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"tiiuae/falcon-7b": {
"architectures": [
"FalconForCausalLM"
],
"model_type": "falcon",
"text_model_type": "falcon",
"hidden_size": 4544,
"total_num_hidden_layers": 32,
"total_num_attention_heads": 71,
"head_size": 64,
"vocab_size": 65024,
"total_num_kv_heads": 1,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"tiiuae/falcon-40b": {
"architectures": [
"FalconForCausalLM"
],
"model_type": "falcon",
"text_model_type": "falcon",
"hidden_size": 8192,
"total_num_hidden_layers": 60,
"total_num_attention_heads": 128,
"head_size": 64,
"vocab_size": 65024,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"luccafong/deepseek_mtp_main_random": {
"architectures": [
"DeepseekV3ForCausalLM"
],
"model_type": "deepseek_v3",
"text_model_type": "deepseek_v3",
"hidden_size": 2560,
"total_num_hidden_layers": 5,
"total_num_attention_heads": 32,
"head_size": 576,
"vocab_size": 129280,
"total_num_kv_heads": 32,
"num_experts": 72,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"luccafong/deepseek_mtp_draft_random": {
"architectures": [
"DeepseekV3ForCausalLM"
],
"model_type": "deepseek_v3",
"text_model_type": "deepseek_v3",
"hidden_size": 2560,
"total_num_hidden_layers": 10,
"total_num_attention_heads": 32,
"head_size": 576,
"vocab_size": 129280,
"total_num_kv_heads": 32,
"num_experts": 72,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"Qwen/Qwen3-Next-80B-A3B-Instruct": {
"architectures": [
"Qwen3NextForCausalLM"
],
"model_type": "qwen3_next",
"text_model_type": "qwen3_next",
"hidden_size": 2048,
"total_num_hidden_layers": 48,
"total_num_attention_heads": 16,
"head_size": 256,
"vocab_size": 151936,
"total_num_kv_heads": 2,
"num_experts": 512,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"tiny-random/qwen3-next-moe": {
"architectures": [
"Qwen3NextForCausalLM"
],
"model_type": "qwen3_next",
"text_model_type": "qwen3_next",
"hidden_size": 8,
"total_num_hidden_layers": 4,
"total_num_attention_heads": 16,
"head_size": 32,
"vocab_size": 151936,
"total_num_kv_heads": 8,
"num_experts": 32,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"zai-org/GLM-4.5": {
"architectures": [
"Glm4MoeForCausalLM"
],
"model_type": "glm4_moe",
"text_model_type": "glm4_moe",
"hidden_size": 5120,
"total_num_hidden_layers": 92,
"total_num_attention_heads": 96,
"head_size": 128,
"vocab_size": 151552,
"total_num_kv_heads": 8,
"num_experts": 160,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"baidu/ERNIE-4.5-21B-A3B-PT": {
"architectures": [
"Ernie4_5_MoeForCausalLM"
],
"model_type": "ernie4_5_moe",
"text_model_type": "ernie4_5_moe",
"hidden_size": 2560,
"total_num_hidden_layers": 28,
"total_num_attention_heads": 20,
"head_size": 128,
"vocab_size": 103424,
"total_num_kv_heads": 4,
"num_experts": 64,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"lmsys/gpt-oss-20b-bf16": {
"architectures": [
"GptOssForCausalLM"
],
"model_type": "gpt_oss",
"text_model_type": "gpt_oss",
"hidden_size": 2880,
"total_num_hidden_layers": 24,
"total_num_attention_heads": 64,
"head_size": 64,
"vocab_size": 201088,
"total_num_kv_heads": 8,
"num_experts": 32,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"deepseek-ai/DeepSeek-V3.2-Exp": {
"architectures": [
"DeepseekV32ForCausalLM"
],
"model_type": "deepseek_v32",
"text_model_type": "deepseek_v32",
"hidden_size": 7168,
"total_num_hidden_layers": 61,
"total_num_attention_heads": 128,
"head_size": 576,
"vocab_size": 129280,
"total_num_kv_heads": 128,
"num_experts": 256,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"meta-llama/Llama-4-Scout-17B-16E-Instruct": {
"architectures": [
"Llama4ForConditionalGeneration"
],
"model_type": "llama4",
"text_model_type": "llama4_text",
"hidden_size": 5120,
"total_num_hidden_layers": 48,
"total_num_attention_heads": 40,
"head_size": 128,
"vocab_size": 202048,
"total_num_kv_heads": 8,
"num_experts": 16,
"is_deepseek_mla": false,
"is_multimodal_model": true,
"dtype": "torch.bfloat16"
},
"nvidia/Llama-3_3-Nemotron-Super-49B-v1": {
"architectures": [
"DeciLMForCausalLM"
],
"model_type": "nemotron-nas",
"text_model_type": "nemotron-nas",
"hidden_size": 8192,
"total_num_hidden_layers": 80,
"total_num_attention_heads": 64,
"head_size": 128,
"vocab_size": 128256,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"XiaomiMiMo/MiMo-7B-RL": {
"architectures": [
"MiMoForCausalLM"
],
"model_type": "mimo",
"text_model_type": "mimo",
"hidden_size": 4096,
"total_num_hidden_layers": 36,
"total_num_attention_heads": 32,
"head_size": 128,
"vocab_size": 151680,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"meituan-longcat/LongCat-Flash-Chat": {
"architectures": [
"LongcatFlashForCausalLM"
],
"model_type": "longcat_flash",
"text_model_type": "longcat_flash",
"hidden_size": 6144,
"total_num_hidden_layers": 28,
"total_num_attention_heads": 64,
"head_size": 576,
"vocab_size": 131072,
"total_num_kv_heads": 64,
"num_experts": 512,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "torch.float32"
}
}
{
"abhigoyal/vllm-medusa-llama-68m-random": {
"architectures": [
"MedusaModel"
],
"model_type": "medusa",
"text_model_type": "medusa",
"hidden_size": 768,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 0,
"head_size": "Error: integer division or modulo by zero",
"vocab_size": 32000,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.float32"
},
"luccafong/deepseek_mtp_draft_random": {
"architectures": [
"DeepSeekMTPModel"
],
"model_type": "deepseek_mtp",
"text_model_type": "deepseek_mtp",
"hidden_size": 2560,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
"head_size": 576,
"vocab_size": 129280,
"total_num_kv_heads": 32,
"num_experts": 72,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"eagle618/eagle-deepseek-v3-random": {
"architectures": [
"EagleDeepSeekMTPModel"
],
"model_type": "eagle",
"text_model_type": "deepseek_mtp",
"hidden_size": 2560,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
"head_size": 576,
"vocab_size": 129280,
"total_num_kv_heads": 32,
"num_experts": 72,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "bfloat16"
},
"yuhuili/EAGLE-LLaMA3-Instruct-8B": {
"architectures": [
"EagleLlamaForCausalLM"
],
"model_type": "eagle",
"text_model_type": "llama",
"hidden_size": 4096,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
"head_size": 128,
"vocab_size": 128256,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "float16"
},
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": {
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "eagle",
"text_model_type": "llama",
"hidden_size": 4096,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
"head_size": 128,
"vocab_size": 128256,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "float16"
}
}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ModelArchitectureConfig and its integration with ModelConfig."""
import json
from pathlib import Path
import pytest
from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig
from vllm.transformers_utils.model_arch_config_convertor import (
ModelArchConfigConvertorBase,
)
# Models whose ModelConfig is built with trust_remote_code=True.
BASE_TRUST_REMOTE_CODE_MODELS = {
    "nvidia/Llama-3_3-Nemotron-Super-49B-v1",
    "XiaomiMiMo/MiMo-7B-RL",
    # Excluded: Not available online right now
    # "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
    "meituan-longcat/LongCat-Flash-Chat",
}

# Every model name parametrized into test_base_model_arch_config.
BASE_MODELS_TO_TEST = [
    "state-spaces/mamba-130m-hf",
    "mistralai/Mamba-Codestral-7B-v0.1",
    # Excluded: terratorch/torchgeo version mismatch in CPU CI environment
    # (NonGeoDataset import error). Tested in model initialization tests.
    # "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    "Zyphra/Zamba2-7B-instruct",
    # FIXME: mosaicml/mpt-7b has been deleted
    # "mosaicml/mpt-7b",
    # FIXME: databricks/dbrx-instruct has been deleted
    # "databricks/dbrx-instruct",
    "tiiuae/falcon-7b",
    "tiiuae/falcon-40b",
    "luccafong/deepseek_mtp_main_random",
    "Qwen/Qwen3-Next-80B-A3B-Instruct",
    "tiny-random/qwen3-next-moe",
    "zai-org/GLM-4.5",
    "baidu/ERNIE-4.5-21B-A3B-PT",
    # Models using base convertor
    "lmsys/gpt-oss-20b-bf16",
    "deepseek-ai/DeepSeek-V3.2-Exp",
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    # A set of strings has no stable iteration order (string hashes are
    # randomized per process), so sort it: the parametrized test order and IDs
    # must be identical across runs and across parallel test workers.
] + sorted(BASE_TRUST_REMOTE_CODE_MODELS)
# (target_model, draft_model, trust_remote_code)
# draft_model is the key looked up in draft_model_arch_groundtruth.json;
# trust_remote_code is forwarded to the target model's ModelConfig.
SPECULATIVE_MODELS = [
    ("JackFram/llama-68m", "abhigoyal/vllm-medusa-llama-68m-random", False),
    ("luccafong/deepseek_mtp_main_random", "luccafong/deepseek_mtp_draft_random", True),
    ("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True),
    ("meta-llama/Meta-Llama-3-8B-Instruct", "yuhuili/EAGLE-LLaMA3-Instruct-8B", True),
    ("meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", True),
]
def _load_groundtruth(filename: str) -> dict:
    """Load a groundtruth JSON file located next to this test module.

    Args:
        filename: File name, resolved relative to this module's directory.

    Returns:
        The parsed JSON document.
    """
    groundtruth_path = Path(__file__).parent / filename
    # JSON is UTF-8 by specification; do not rely on the locale's default
    # text encoding, which differs between platforms.
    with open(groundtruth_path, encoding="utf-8") as f:
        return json.load(f)
def _assert_model_arch_config(
    model_config, expected: dict, check_head_size: bool = True
):
    """Compare ``model_config.model_arch_config`` with groundtruth values.

    Args:
        model_config: Object exposing ``model_arch_config``, ``hf_config``,
            ``model`` and ``revision``.
        expected: Groundtruth entry keyed by field name.
        check_head_size: When False, ``head_size`` is neither read nor
            compared.

    Raises:
        AssertionError: On the first field that differs from ``expected``.
    """
    arch_config = model_config.model_arch_config

    # Fields compared one-to-one, in this order, against the groundtruth.
    compared_fields = (
        "architectures",
        "model_type",
        "text_model_type",
        "hidden_size",
        "total_num_hidden_layers",
        "total_num_attention_heads",
        "vocab_size",
        "total_num_kv_heads",
        "num_experts",
        "is_deepseek_mla",
    )
    for field_name in compared_fields:
        assert getattr(arch_config, field_name) == expected[field_name]

    # The dtype is resolved through the convertor and compared by its string
    # form rather than read from model_arch_config.
    resolved_dtype = ModelArchConfigConvertorBase.get_torch_dtype(
        model_config.hf_config, model_config.model, revision=model_config.revision
    )
    assert str(resolved_dtype) == expected["dtype"]

    if not check_head_size:
        return
    assert arch_config.head_size == expected["head_size"]
def _assert_model_config_methods(
    model_config, expected: dict, check_head_size: bool = True
):
    """Compare the accessor methods of ``model_config`` with groundtruth.

    Args:
        model_config: Object exposing ``architectures`` and the ``get_*``
            accessors listed below.
        expected: Groundtruth entry keyed by field name.
        check_head_size: When False, ``get_head_size()`` is never called.

    Raises:
        AssertionError: On the first value that differs from ``expected``.
    """
    assert model_config.architectures == expected["architectures"]

    # (accessor name, groundtruth key), checked in this order.
    accessor_checks = [
        ("get_vocab_size", "vocab_size"),
        ("get_hidden_size", "hidden_size"),
        ("get_total_num_kv_heads", "total_num_kv_heads"),
        ("get_num_experts", "num_experts"),
        ("get_total_num_hidden_layers", "total_num_hidden_layers"),
    ]
    if check_head_size:
        accessor_checks.append(("get_head_size", "head_size"))

    for accessor_name, key in accessor_checks:
        assert getattr(model_config, accessor_name)() == expected[key]
@pytest.mark.parametrize("model", BASE_MODELS_TO_TEST)
def test_base_model_arch_config(model: str):
    """Check a base model's architecture config against its groundtruth."""
    needs_remote_code = model in BASE_TRUST_REMOTE_CODE_MODELS
    expected = _load_groundtruth("base_model_arch_groundtruth.json")[model]

    model_config = ModelConfig(model, trust_remote_code=needs_remote_code)

    # Both views must agree with the groundtruth: the model_arch_config
    # attributes and the ModelConfig accessor methods.
    for check in (_assert_model_arch_config, _assert_model_config_methods):
        check(model_config, expected)
@pytest.mark.parametrize(
    "target_model,draft_model,trust_remote_code", SPECULATIVE_MODELS
)
def test_draft_model_arch_config(
    target_model: str, draft_model: str, trust_remote_code: bool
):
    """Check a draft model's architecture config against its groundtruth.

    The draft ``ModelConfig`` is taken from a ``SpeculativeConfig`` built
    around the target model.
    """
    expected = _load_groundtruth("draft_model_arch_groundtruth.json")[draft_model]

    draft_model_config = SpeculativeConfig(
        model=draft_model,
        num_speculative_tokens=1,
        target_model_config=ModelConfig(
            target_model, trust_remote_code=trust_remote_code
        ),
        target_parallel_config=ParallelConfig(),
    ).draft_model_config

    # For medusa models, head_size may cause division by zero before
    # model_arch_config was introduced; the groundtruth then holds an error
    # string instead of an int, and the head_size comparison is skipped.
    head_size_is_known = isinstance(expected["head_size"], int)

    _assert_model_arch_config(
        draft_model_config, expected, check_head_size=head_size_is_known
    )
    _assert_model_config_methods(
        draft_model_config, expected, check_head_size=head_size_is_known
    )
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import pytest import pytest
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MultiModalConfig from vllm.config.multimodal import MultiModalConfig
from vllm.v1.attention.backends.registry import AttentionBackendEnum
def test_mm_encoder_attn_backend_str_conversion(): def test_mm_encoder_attn_backend_str_conversion():
......
...@@ -47,7 +47,11 @@ from transformers import ( ...@@ -47,7 +47,11 @@ from transformers import (
) )
from transformers.models.auto.auto_factory import _BaseAutoModelClass from transformers.models.auto.auto_factory import _BaseAutoModelClass
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs from tests.models.utils import (
TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
softmax,
)
from vllm import LLM, SamplingParams, envs from vllm import LLM, SamplingParams, envs
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
...@@ -189,6 +193,17 @@ def dist_init(): ...@@ -189,6 +193,17 @@ def dist_init():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.fixture
def default_vllm_config():
"""Set a default VllmConfig for tests that directly test CustomOps or pathways
that use get_current_vllm_config() outside of a full engine context.
"""
from vllm.config import VllmConfig, set_current_vllm_config
with set_current_vllm_config(VllmConfig()):
yield
@pytest.fixture() @pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool: def should_do_global_cleanup_after_test(request) -> bool:
"""Allow subdirectories to skip global cleanup by overriding this fixture. """Allow subdirectories to skip global cleanup by overriding this fixture.
...@@ -414,7 +429,7 @@ class HfRunner: ...@@ -414,7 +429,7 @@ class HfRunner:
# don't put this import at the top level # don't put this import at the top level
# it will call torch.cuda.device_count() # it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401 from transformers import AutoProcessor
self.processor = AutoProcessor.from_pretrained( self.processor = AutoProcessor.from_pretrained(
model_name, model_name,
...@@ -517,7 +532,7 @@ class HfRunner: ...@@ -517,7 +532,7 @@ class HfRunner:
elif problem_type == "multi_label_classification": elif problem_type == "multi_label_classification":
logits = output.logits.sigmoid()[0].tolist() logits = output.logits.sigmoid()[0].tolist()
else: else:
logits = output.logits.softmax(dim=-1)[0].tolist() logits = softmax(output.logits)[0].tolist()
outputs.append(logits) outputs.append(logits)
return outputs return outputs
...@@ -685,6 +700,7 @@ class HfRunner: ...@@ -685,6 +700,7 @@ class HfRunner:
images: PromptImageInput | None = None, images: PromptImageInput | None = None,
audios: PromptAudioInput | None = None, audios: PromptAudioInput | None = None,
videos: PromptVideoInput | None = None, videos: PromptVideoInput | None = None,
use_cache: bool = True,
**kwargs: Any, **kwargs: Any,
) -> list[TokensTextLogprobs]: ) -> list[TokensTextLogprobs]:
all_inputs = self.get_inputs( all_inputs = self.get_inputs(
...@@ -698,7 +714,7 @@ class HfRunner: ...@@ -698,7 +714,7 @@ class HfRunner:
for inputs in all_inputs: for inputs in all_inputs:
output: "GenerateOutput" = self.model.generate( output: "GenerateOutput" = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
use_cache=True, use_cache=use_cache,
do_sample=False, do_sample=False,
max_new_tokens=max_tokens, max_new_tokens=max_tokens,
output_hidden_states=True, output_hidden_states=True,
......
...@@ -219,14 +219,12 @@ def _test_cp_gsm8k( ...@@ -219,14 +219,12 @@ def _test_cp_gsm8k(
] ]
) )
server_env = {}
if attn_backend: if attn_backend:
server_env["VLLM_ATTENTION_BACKEND"] = attn_backend server_args.append(f"--attention-backend={attn_backend}")
with RemoteOpenAIServer( with RemoteOpenAIServer(
model_id, model_id,
server_args, server_args,
env_dict=server_env,
max_wait_seconds=720, max_wait_seconds=720,
) as remote_server: ) as remote_server:
host = f"http://{remote_server.host}" host = f"http://{remote_server.host}"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest import pytest
import torch import torch
...@@ -310,3 +311,143 @@ if __name__ == "__main__": ...@@ -310,3 +311,143 @@ if __name__ == "__main__":
print(phy2log) print(phy2log)
test_basic_rebalance() test_basic_rebalance()
def _make_phy_replicas_idx_from_phy2log(phy2log: np.ndarray) -> np.ndarray:
    """Derive a per-slot replica index from a physical-to-logical mapping.

    Within each row of ``phy2log`` the k-th occurrence of an expert id
    (counting from 0, left to right) gets replica index ``k``.

    Returns:
        An int64 array with the same shape as ``phy2log``.
    """
    replica_idx = np.zeros_like(phy2log, dtype=np.int64)
    for layer_idx in range(phy2log.shape[0]):
        # expert id -> number of occurrences already seen in this row
        occurrences: dict[int, int] = {}
        for slot, expert_id in enumerate(phy2log[layer_idx].tolist()):
            replica_idx[layer_idx, slot] = occurrences.setdefault(expert_id, 0)
            occurrences[expert_id] += 1
    return replica_idx
def _validate_intragpu_rearrangement(
    old_global_expert_indices: np.ndarray,
    new_phy2log: np.ndarray,
    new_phy_replicas_idx: np.ndarray,
    post_phy2log: np.ndarray,
    post_phy_replicas_idx: np.ndarray,
    num_ranks: int,
    slots_per_gpu: int,
):
    """Assert that a post-processed mapping is a valid intra-GPU reordering.

    Only layer 0 of every array is inspected. For each GPU (a contiguous run
    of ``slots_per_gpu`` slots) two properties are checked:

    * the multiset of (expert, replica index) pairs in the post mapping
      equals the one in the new mapping, so nothing is lost or invented;
    * every expert present on the GPU in both the old and the new mapping
      sits, in the post mapping, at the first slot it occupied in the old
      mapping, with a replica index that the new mapping assigns to it.

    Raises:
        AssertionError: If either property is violated on any GPU.
    """
    for gpu_idx in range(num_ranks):
        gpu_slots = slice(gpu_idx * slots_per_gpu, (gpu_idx + 1) * slots_per_gpu)

        old_experts = old_global_expert_indices[0, gpu_slots].tolist()
        new_experts = new_phy2log[0, gpu_slots].tolist()
        new_ranks = new_phy_replicas_idx[0, gpu_slots].tolist()
        post_experts = post_phy2log[0, gpu_slots].tolist()
        post_ranks = post_phy_replicas_idx[0, gpu_slots].tolist()

        # Pairwise equality for (expert, rank) pairs to ensure nothing is lost
        assert sorted(zip(post_experts, post_ranks)) == sorted(
            zip(new_experts, new_ranks)
        ), f"Per-GPU pairs of (expert,rank) must match new mapping for GPU {gpu_idx}"

        # expert id -> every replica index the new mapping gives it on this GPU
        ranks_by_expert: dict[int, list[int]] = {}
        for expert, rank in zip(new_experts, new_ranks):
            ranks_by_expert.setdefault(expert, []).append(rank)

        # Experts that stay on this GPU must keep their (first) old slot.
        for expert in set(old_experts) & set(new_experts):
            old_pos = old_experts.index(expert)
            assert post_experts[old_pos] == expert, (
                f"Expert {expert} on GPU {gpu_idx} should stay at old slot {old_pos}"
            )
            # Rank at preserved slot must be one of the ranks
            # the expert has in new mapping
            assert post_ranks[old_pos] in ranks_by_expert[expert], (
                f"Rank for expert {expert} at preserved slot on GPU {gpu_idx} "
                "must come from new mapping"
            )
@pytest.mark.parametrize(
    "num_ranks, slots_per_gpu, old_phy2log, new_phy2log",
    [
        pytest.param(
            # Setup: 2 GPUs, 4 slots each, 1 layer
            # Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7]
            # New mapping shuffles within GPU0 and brings 4,5 into GPU0.
            # GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
            2,
            4,
            np.array([[0, 1, 2, 3, 4, 5, 6, 7]]),
            np.array([[1, 5, 0, 4, 6, 2, 7, 3]]),
            id="simple",
        ),
        pytest.param(
            # Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer
            # Old mapping:
            #   GPU0 -> [0, 1, 0, 2, 3]  (expert 0 duplicated)
            #   GPU1 -> [4, 5, 6, 1, 2]
            # New mapping reorders within GPUs and moves some experts across GPUs,
            # while still including duplicates:
            #   GPU0 new -> [0, 5, 4, 0, 1]  (expert 0 duplicated, 4/5 incoming)
            #   GPU1 new -> [6, 2, 3, 2, 1]  (expert 2 duplicated)
            2,
            5,
            np.array([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]),
            np.array([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]),
            id="duplicates",
        ),
        pytest.param(
            # Setup: 3 GPUs, 4 slots each (total 12 physical experts), 1 layer
            # Old mapping:
            #   GPU0 -> [0, 1, 2, 3]
            #   GPU1 -> [0, 1, 2, 3]
            #   GPU2 -> [0, 1, 2, 3]
            # New mapping decides to use one expert on 2 GPUs and shuffles
            # experts on the third GPU:
            #   GPU0 new -> [0, 0, 0, 0]
            #   GPU1 new -> [0, 0, 0, 0]
            #   GPU2 new -> [1, 2, 3, 0]
            3,
            4,
            np.array([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
            np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]),
            id="skewed_expert",
        ),
    ],
)
def test_preserve_intragpu_slots(
    num_ranks: int,
    slots_per_gpu: int,
    old_phy2log: np.ndarray,
    new_phy2log: np.ndarray,
):
    """Experts that stay on a GPU keep their old slots; incoming not lost.

    Args:
        num_ranks: Number of GPUs the physical slots are split across.
        slots_per_gpu: Number of consecutive physical slots owned by each GPU.
        old_phy2log: Physical-to-logical expert mapping before rebalancing
            (NumPy array, one row per layer).
        new_phy2log: Mapping proposed by rebalancing, same layout.
    """
    phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)

    post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
        new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
    )

    # Shapes preserved
    assert post_phy2log.shape == new_phy2log.shape
    assert post_phy_replicas_idx.shape == phy_replicas_idx.shape

    _validate_intragpu_rearrangement(
        old_phy2log,
        new_phy2log,
        phy_replicas_idx,
        post_phy2log,
        post_phy_replicas_idx,
        num_ranks,
        slots_per_gpu,
    )
...@@ -286,15 +286,17 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -286,15 +286,17 @@ def _test_async_transfer_layer_without_mtp_worker(
device, device,
old_indices, old_indices,
) )
old_indices_cpu = old_indices.cpu()
new_indices_cpu = new_indices.cpu()
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]] expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
cuda_stream = torch.cuda.Stream(device=device) cuda_stream = torch.cuda.Stream(device=device)
for layer_idx in range(num_layers): for layer_idx in range(num_layers):
is_unchanged, is_received_locally, experts_recv_loc = asyncio.run( is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer( transfer_layer(
old_global_expert_indices=old_indices, old_global_expert_indices=old_indices_cpu,
new_global_expert_indices=new_indices, new_global_expert_indices=new_indices_cpu,
expert_weights=expert_weights, expert_weights=expert_weights,
expert_weights_buffer=expert_buffer, expert_weights_buffer=expert_buffer,
ep_group=ep_group, ep_group=ep_group,
...@@ -302,16 +304,15 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -302,16 +304,15 @@ def _test_async_transfer_layer_without_mtp_worker(
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
) )
) )
cuda_stream.synchronize() cuda_stream.synchronize()
move_from_buffer( move_from_buffer(
expert_weights=expert_weights[layer_idx], expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer, expert_weights_buffers=expert_buffer,
is_unchanged=is_unchanged, is_unchanged=is_unchanged,
is_received_locally=is_received_locally, is_received_locally=is_received_locally,
experts_recv_loc=experts_recv_loc, recv_metadata=recv_metadata,
new_indices=new_indices[layer_idx].tolist(), new_indices=new_indices_cpu[layer_idx].numpy(),
ep_group=ep_group, ep_rank=ep_rank,
) )
verify_expert_weights_after_shuffle( verify_expert_weights_after_shuffle(
......
...@@ -21,23 +21,21 @@ from ..utils import compare_two_settings, create_new_process_for_each_test, mode ...@@ -21,23 +21,21 @@ from ..utils import compare_two_settings, create_new_process_for_each_test, mode
) )
@create_new_process_for_each_test()
def test_pp_cudagraph(
    PP_SIZE: int,
    MODEL_NAME: str,
    ATTN_BACKEND: LiteralString,
):
    """Compare an eager run against a CUDA-graph run under pipeline parallelism.

    Both runs share one argument list; the eager variant only appends
    ``--enforce-eager``. The two lists are handed to ``compare_two_settings``
    together with ``MODEL_NAME``.

    Args:
        PP_SIZE: Value passed to ``--pipeline-parallel-size``.
        MODEL_NAME: Model identifier forwarded to ``compare_two_settings``.
        ATTN_BACKEND: Backend name passed via ``--attention-backend``.
    """
    cudagraph_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--pipeline-parallel-size",
        str(PP_SIZE),
        "--distributed-executor-backend",
        "mp",
        # The backend is selected on the command line; the test no longer
        # sets the VLLM_ATTENTION_BACKEND environment variable.
        f"--attention-backend={ATTN_BACKEND}",
    ]
    eager_args = cudagraph_args + ["--enforce-eager"]

    compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)
...@@ -9,7 +9,7 @@ from typing import Annotated, Literal ...@@ -9,7 +9,7 @@ from typing import Annotated, Literal
import pytest import pytest
from vllm.config import CompilationConfig, config from vllm.config import AttentionConfig, CompilationConfig, config
from vllm.engine.arg_utils import ( from vllm.engine.arg_utils import (
EngineArgs, EngineArgs,
contains_type, contains_type,
...@@ -298,6 +298,139 @@ def test_compilation_config(): ...@@ -298,6 +298,139 @@ def test_compilation_config():
) )
def test_attention_config():
    """Exercise the CLI routes for configuring attention on ``EngineArgs``.

    Scenarios, in order: parser defaults, ``--attention-config.backend`` dot
    notation, the ``--attention-backend`` shorthand, every field via dot
    notation, every field via a dict-style ``--attention-config=`` string,
    propagation of either backend spelling into the config returned by
    ``create_engine_config``, and the ``ValueError`` raised when both
    backend spellings are supplied together.
    """
    from vllm.v1.attention.backends.registry import AttentionBackendEnum

    cli_parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    def engine_args_for(argv):
        # Parse ``argv`` and turn the resulting namespace into EngineArgs.
        namespace = cli_parser.parse_args(argv)
        assert namespace is not None
        return EngineArgs.from_cli_args(namespace)

    # default value
    defaults = engine_args_for([])
    assert defaults.attention_config == AttentionConfig()

    # set backend via dot notation
    attn = engine_args_for(
        ["--attention-config.backend", "FLASH_ATTN"]
    ).attention_config
    assert attn.backend is not None
    assert attn.backend.name == "FLASH_ATTN"

    # set backend via --attention-backend shorthand
    shorthand = engine_args_for(["--attention-backend", "FLASHINFER"])
    assert shorthand.attention_backend is not None
    assert shorthand.attention_backend == "FLASHINFER"

    # set all fields via dot notation
    dotted_argv = [
        "--attention-config.backend",
        "FLASH_ATTN",
        "--attention-config.flash_attn_version",
        "3",
        "--attention-config.use_prefill_decode_attention",
        "true",
        "--attention-config.flash_attn_max_num_splits_for_cuda_graph",
        "16",
        "--attention-config.use_cudnn_prefill",
        "true",
        "--attention-config.use_trtllm_ragged_deepseek_prefill",
        "true",
        "--attention-config.use_trtllm_attention",
        "true",
        "--attention-config.disable_flashinfer_prefill",
        "true",
        "--attention-config.disable_flashinfer_q_quantization",
        "true",
    ]
    attn = engine_args_for(dotted_argv).attention_config
    assert attn.backend is not None
    assert attn.backend.name == "FLASH_ATTN"
    assert attn.flash_attn_version == 3
    assert attn.use_prefill_decode_attention is True
    assert attn.flash_attn_max_num_splits_for_cuda_graph == 16
    assert attn.use_cudnn_prefill is True
    assert attn.use_trtllm_ragged_deepseek_prefill is True
    assert attn.use_trtllm_attention is True
    assert attn.disable_flashinfer_prefill is True
    assert attn.disable_flashinfer_q_quantization is True

    # set to string form of a dict with all fields
    dict_form_argv = [
        "--attention-config="
        '{"backend": "FLASHINFER", "flash_attn_version": 2, '
        '"use_prefill_decode_attention": false, '
        '"flash_attn_max_num_splits_for_cuda_graph": 8, '
        '"use_cudnn_prefill": false, '
        '"use_trtllm_ragged_deepseek_prefill": false, '
        '"use_trtllm_attention": false, '
        '"disable_flashinfer_prefill": false, '
        '"disable_flashinfer_q_quantization": false}',
    ]
    attn = engine_args_for(dict_form_argv).attention_config
    assert attn.backend is not None
    assert attn.backend.name == "FLASHINFER"
    assert attn.flash_attn_version == 2
    assert attn.use_prefill_decode_attention is False
    assert attn.flash_attn_max_num_splits_for_cuda_graph == 8
    assert attn.use_cudnn_prefill is False
    assert attn.use_trtllm_ragged_deepseek_prefill is False
    assert attn.use_trtllm_attention is False
    assert attn.disable_flashinfer_prefill is False
    assert attn.disable_flashinfer_q_quantization is False

    # test --attention-backend flows into VllmConfig.attention_config
    vllm_config = engine_args_for(
        [
            "--model",
            "facebook/opt-125m",
            "--attention-backend",
            "FLASH_ATTN",
        ]
    ).create_engine_config()
    assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASH_ATTN

    # test --attention-config.backend flows into VllmConfig.attention_config
    vllm_config = engine_args_for(
        [
            "--model",
            "facebook/opt-125m",
            "--attention-config.backend",
            "FLASHINFER",
        ]
    ).create_engine_config()
    assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASHINFER

    # test --attention-backend and --attention-config.backend are mutually exclusive
    conflicting = engine_args_for(
        [
            "--model",
            "facebook/opt-125m",
            "--attention-backend",
            "FLASH_ATTN",
            "--attention-config.backend",
            "FLASHINFER",
        ]
    )
    with pytest.raises(ValueError, match="mutually exclusive"):
        conflicting.create_engine_config()
def test_prefix_cache_default(): def test_prefix_cache_default():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([]) args = parser.parse_args([])
...@@ -378,6 +511,16 @@ def test_human_readable_model_len(): ...@@ -378,6 +511,16 @@ def test_human_readable_model_len():
args = parser.parse_args(["--max-model-len", "10.2123451234567t"]) args = parser.parse_args(["--max-model-len", "10.2123451234567t"])
assert args.max_model_len == 10212345123456 assert args.max_model_len == 10212345123456
# Special value -1 for auto-fit to GPU memory
args = parser.parse_args(["--max-model-len", "-1"])
assert args.max_model_len == -1
# 'auto' is an alias for -1
args = parser.parse_args(["--max-model-len", "auto"])
assert args.max_model_len == -1
args = parser.parse_args(["--max-model-len", "AUTO"])
assert args.max_model_len == -1
# Invalid (do not allow decimals with binary multipliers) # Invalid (do not allow decimals with binary multipliers)
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]: for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
with pytest.raises(ArgumentError): with pytest.raises(ArgumentError):
......
...@@ -15,9 +15,9 @@ import requests ...@@ -15,9 +15,9 @@ import requests
from prometheus_client.parser import text_string_to_metric_families from prometheus_client.parser import text_string_to_metric_families
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tests.conftest import LocalAssetServer
from tests.utils import RemoteOpenAIServer
from vllm import version from vllm import version
from ...conftest import LocalAssetServer
from ...utils import RemoteOpenAIServer, models_path_prefix from ...utils import RemoteOpenAIServer, models_path_prefix
MODELS = { MODELS = {
......
...@@ -5,6 +5,30 @@ import pytest ...@@ -5,6 +5,30 @@ import pytest
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
def add_attention_backend(server_args, attention_config):
"""Append attention backend CLI arg if specified.
Args:
server_args: List of server arguments to extend in-place.
attention_config: Dict with 'backend' key, or None.
"""
if attention_config and "backend" in attention_config:
server_args.extend(["--attention-backend", attention_config["backend"]])
@pytest.fixture(scope="module")
def rocm_aiter_fa_attention():
    """Provide the attention config used by transcription/translation tests.

    Yields ``{"backend": "ROCM_AITER_FA"}`` when the current platform reports
    ROCm (audio tests require that backend there), and ``None`` otherwise.
    """
    from vllm.platforms import current_platform

    return {"backend": "ROCM_AITER_FA"} if current_platform.is_rocm() else None
@pytest.fixture @pytest.fixture
def mary_had_lamb(): def mary_had_lamb():
path = AudioAsset("mary_had_lamb").get_local_path() path = AudioAsset("mary_had_lamb").get_local_path()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment