Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
...@@ -13,12 +13,15 @@ import pytest ...@@ -13,12 +13,15 @@ import pytest
import torch import torch
from vllm import LLM from vllm import LLM
from vllm.platforms import current_platform
from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.engine.llm_engine import LLMEngine
from ..conftest import HfRunner, VllmRunner from ..conftest import HfRunner, VllmRunner
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]
MODELS = [ MODELS = [
"hmellor/tiny-random-Gemma2ForCausalLM", "hmellor/tiny-random-Gemma2ForCausalLM",
"meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
...@@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs( ...@@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("backend", ATTN_BACKEND)
@pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False]) @pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("async_scheduling", [True, False]) @pytest.mark.parametrize("async_scheduling", [True, False])
......
...@@ -260,13 +260,18 @@ def test_deep_sleep_fp8_kvcache(): ...@@ -260,13 +260,18 @@ def test_deep_sleep_fp8_kvcache():
llm.sleep(level=2) llm.sleep(level=2)
used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
assert used_bytes < 3 * GiB_bytes
# Rocm uses more memory for CudaGraphs, so we add 2 GiB more for the threshold
rocm_extra_mem_bytes = 2 * GiB_bytes if current_platform.is_rocm() else 0
mem_threshold_after_sleep = 3 * GiB_bytes + rocm_extra_mem_bytes
assert used_bytes < mem_threshold_after_sleep
llm.wake_up(tags=["weights"]) llm.wake_up(tags=["weights"])
llm.collective_rpc("reload_weights") llm.collective_rpc("reload_weights")
used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
assert used_bytes < 4 * GiB_bytes mem_threshold_after_wake_up = 4 * GiB_bytes + rocm_extra_mem_bytes
assert used_bytes < mem_threshold_after_wake_up
# now allocate kv cache and cuda graph memory # now allocate kv cache and cuda graph memory
llm.wake_up(tags=["kv_cache"]) llm.wake_up(tags=["kv_cache"])
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import tempfile
from pathlib import Path
import pytest
from vllm.benchmarks.sweep.param_sweep import ParameterSweep, ParameterSweepItem
class TestParameterSweepItem:
"""Test ParameterSweepItem functionality."""
@pytest.mark.parametrize(
"input_dict,expected",
[
(
{"compilation_config.use_inductor_graph_partition": False},
"--compilation-config.use_inductor_graph_partition=false",
),
(
{"compilation_config.use_inductor_graph_partition": True},
"--compilation-config.use_inductor_graph_partition=true",
),
(
{"compilation_config.use_inductor": False},
"--compilation-config.use_inductor=false",
),
(
{"compilation_config.use_inductor": True},
"--compilation-config.use_inductor=true",
),
],
)
def test_nested_boolean_params(self, input_dict, expected):
"""Test that nested boolean params use =true/false syntax."""
item = ParameterSweepItem.from_record(input_dict)
cmd = item.apply_to_cmd(["vllm", "serve", "model"])
assert expected in cmd
@pytest.mark.parametrize(
"input_dict,expected",
[
({"enable_prefix_caching": False}, "--no-enable-prefix-caching"),
({"enable_prefix_caching": True}, "--enable-prefix-caching"),
({"disable_log_stats": False}, "--no-disable-log-stats"),
({"disable_log_stats": True}, "--disable-log-stats"),
],
)
def test_non_nested_boolean_params(self, input_dict, expected):
"""Test that non-nested boolean params use --no- prefix."""
item = ParameterSweepItem.from_record(input_dict)
cmd = item.apply_to_cmd(["vllm", "serve", "model"])
assert expected in cmd
@pytest.mark.parametrize(
"compilation_config",
[
{"cudagraph_mode": "full", "mode": 2, "use_inductor_graph_partition": True},
{
"cudagraph_mode": "piecewise",
"mode": 3,
"use_inductor_graph_partition": False,
},
],
)
def test_nested_dict_value(self, compilation_config):
"""Test that nested dict values are serialized as JSON."""
item = ParameterSweepItem.from_record(
{"compilation_config": compilation_config}
)
cmd = item.apply_to_cmd(["vllm", "serve", "model"])
assert "--compilation-config" in cmd
# The dict should be JSON serialized
idx = cmd.index("--compilation-config")
assert json.loads(cmd[idx + 1]) == compilation_config
@pytest.mark.parametrize(
"input_dict,expected_key,expected_value",
[
({"model": "test-model"}, "--model", "test-model"),
({"max_tokens": 100}, "--max-tokens", "100"),
({"temperature": 0.7}, "--temperature", "0.7"),
],
)
def test_string_and_numeric_values(self, input_dict, expected_key, expected_value):
"""Test that string and numeric values are handled correctly."""
item = ParameterSweepItem.from_record(input_dict)
cmd = item.apply_to_cmd(["vllm", "serve"])
assert expected_key in cmd
assert expected_value in cmd
@pytest.mark.parametrize(
"input_dict,expected_key,key_idx_offset",
[
({"max_tokens": 200}, "--max-tokens", 1),
({"enable_prefix_caching": False}, "--no-enable-prefix-caching", 0),
],
)
def test_replace_existing_parameter(self, input_dict, expected_key, key_idx_offset):
"""Test that existing parameters in cmd are replaced."""
item = ParameterSweepItem.from_record(input_dict)
if key_idx_offset == 1:
# Key-value pair
cmd = item.apply_to_cmd(["vllm", "serve", "--max-tokens", "100", "model"])
assert expected_key in cmd
idx = cmd.index(expected_key)
assert cmd[idx + 1] == "200"
assert "100" not in cmd
else:
# Boolean flag
cmd = item.apply_to_cmd(
["vllm", "serve", "--enable-prefix-caching", "model"]
)
assert expected_key in cmd
assert "--enable-prefix-caching" not in cmd
class TestParameterSweep:
"""Test ParameterSweep functionality."""
def test_from_records_list(self):
"""Test creating ParameterSweep from a list of records."""
records = [
{"max_tokens": 100, "temperature": 0.7},
{"max_tokens": 200, "temperature": 0.9},
]
sweep = ParameterSweep.from_records(records)
assert len(sweep) == 2
assert sweep[0]["max_tokens"] == 100
assert sweep[1]["max_tokens"] == 200
def test_read_from_dict(self):
"""Test creating ParameterSweep from a dict format."""
data = {
"experiment1": {"max_tokens": 100, "temperature": 0.7},
"experiment2": {"max_tokens": 200, "temperature": 0.9},
}
sweep = ParameterSweep.read_from_dict(data)
assert len(sweep) == 2
# Check that items have the _benchmark_name field
names = {item["_benchmark_name"] for item in sweep}
assert names == {"experiment1", "experiment2"}
# Check that parameters are preserved
for item in sweep:
if item["_benchmark_name"] == "experiment1":
assert item["max_tokens"] == 100
assert item["temperature"] == 0.7
elif item["_benchmark_name"] == "experiment2":
assert item["max_tokens"] == 200
assert item["temperature"] == 0.9
def test_read_json_list_format(self):
"""Test reading JSON file with list format."""
records = [
{"max_tokens": 100, "temperature": 0.7},
{"max_tokens": 200, "temperature": 0.9},
]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(records, f)
temp_path = Path(f.name)
try:
sweep = ParameterSweep.read_json(temp_path)
assert len(sweep) == 2
assert sweep[0]["max_tokens"] == 100
assert sweep[1]["max_tokens"] == 200
finally:
temp_path.unlink()
def test_read_json_dict_format(self):
"""Test reading JSON file with dict format."""
data = {
"experiment1": {"max_tokens": 100, "temperature": 0.7},
"experiment2": {"max_tokens": 200, "temperature": 0.9},
}
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(data, f)
temp_path = Path(f.name)
try:
sweep = ParameterSweep.read_json(temp_path)
assert len(sweep) == 2
# Check that items have the _benchmark_name field
names = {item["_benchmark_name"] for item in sweep}
assert names == {"experiment1", "experiment2"}
finally:
temp_path.unlink()
def test_unique_benchmark_names_validation(self):
"""Test that duplicate _benchmark_name values raise an error."""
# Test with duplicate names in list format
records = [
{"_benchmark_name": "exp1", "max_tokens": 100},
{"_benchmark_name": "exp1", "max_tokens": 200},
]
with pytest.raises(ValueError, match="Duplicate _benchmark_name values"):
ParameterSweep.from_records(records)
def test_unique_benchmark_names_multiple_duplicates(self):
"""Test validation with multiple duplicate names."""
records = [
{"_benchmark_name": "exp1", "max_tokens": 100},
{"_benchmark_name": "exp1", "max_tokens": 200},
{"_benchmark_name": "exp2", "max_tokens": 300},
{"_benchmark_name": "exp2", "max_tokens": 400},
]
with pytest.raises(ValueError, match="Duplicate _benchmark_name values"):
ParameterSweep.from_records(records)
def test_no_benchmark_names_allowed(self):
"""Test that records without _benchmark_name are allowed."""
records = [
{"max_tokens": 100, "temperature": 0.7},
{"max_tokens": 200, "temperature": 0.9},
]
sweep = ParameterSweep.from_records(records)
assert len(sweep) == 2
def test_mixed_benchmark_names_allowed(self):
"""Test that mixing records with and without _benchmark_name is allowed."""
records = [
{"_benchmark_name": "exp1", "max_tokens": 100},
{"max_tokens": 200, "temperature": 0.9},
]
sweep = ParameterSweep.from_records(records)
assert len(sweep) == 2
class TestParameterSweepItemKeyNormalization:
"""Test key normalization in ParameterSweepItem."""
def test_underscore_to_hyphen_conversion(self):
"""Test that underscores are converted to hyphens in CLI."""
item = ParameterSweepItem.from_record({"max_tokens": 100})
cmd = item.apply_to_cmd(["vllm", "serve"])
assert "--max-tokens" in cmd
def test_nested_key_preserves_suffix(self):
"""Test that nested keys preserve the suffix format."""
# The suffix after the dot should preserve underscores
item = ParameterSweepItem.from_record(
{"compilation_config.some_nested_param": "value"}
)
cmd = item.apply_to_cmd(["vllm", "serve"])
# The prefix (compilation_config) gets converted to hyphens,
# but the suffix (some_nested_param) is preserved
assert any("compilation-config.some_nested_param" in arg for arg in cmd)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pandas as pd
import pytest
from vllm.benchmarks.sweep.plot import (
PlotEqualTo,
PlotFilterBase,
PlotFilters,
PlotGreaterThan,
PlotGreaterThanOrEqualTo,
PlotLessThan,
PlotLessThanOrEqualTo,
PlotNotEqualTo,
)
class TestPlotFilters:
"""Test PlotFilter functionality including 'inf' edge case."""
def setup_method(self):
"""Create sample DataFrames for testing."""
# DataFrame with numeric values
self.df_numeric = pd.DataFrame(
{
"request_rate": [1.0, 5.0, 10.0, 50.0, 100.0],
"value": [10, 20, 30, 40, 50],
}
)
# DataFrame with float('inf') - note: string "inf" values are coerced
# to float when loading data, so we only test with float('inf')
self.df_inf_float = pd.DataFrame(
{
"request_rate": [1.0, 5.0, 10.0, float("inf"), float("inf")],
"value": [10, 20, 30, 40, 50],
}
)
@pytest.mark.parametrize(
"target,expected_count",
[
("5.0", 1),
("10.0", 1),
("1.0", 1),
],
)
def test_equal_to_numeric(self, target, expected_count):
"""Test PlotEqualTo with numeric values."""
filter_obj = PlotEqualTo("request_rate", target)
result = filter_obj.apply(self.df_numeric)
assert len(result) == expected_count
def test_equal_to_inf_float(self):
"""Test PlotEqualTo with float('inf')."""
filter_obj = PlotEqualTo("request_rate", "inf")
result = filter_obj.apply(self.df_inf_float)
# Should match both float('inf') entries because float('inf') == float('inf')
assert len(result) == 2
@pytest.mark.parametrize(
"target,expected_count",
[
("5.0", 4), # All except 5.0
("1.0", 4), # All except 1.0
],
)
def test_not_equal_to_numeric(self, target, expected_count):
"""Test PlotNotEqualTo with numeric values."""
filter_obj = PlotNotEqualTo("request_rate", target)
result = filter_obj.apply(self.df_numeric)
assert len(result) == expected_count
def test_not_equal_to_inf_float(self):
"""Test PlotNotEqualTo with float('inf')."""
filter_obj = PlotNotEqualTo("request_rate", "inf")
result = filter_obj.apply(self.df_inf_float)
# Should exclude float('inf') entries
assert len(result) == 3
@pytest.mark.parametrize(
"target,expected_count",
[
("10.0", 2), # 1.0, 5.0
("50.0", 3), # 1.0, 5.0, 10.0
("5.0", 1), # 1.0
],
)
def test_less_than(self, target, expected_count):
"""Test PlotLessThan with numeric values."""
filter_obj = PlotLessThan("request_rate", target)
result = filter_obj.apply(self.df_numeric)
assert len(result) == expected_count
@pytest.mark.parametrize(
"target,expected_count",
[
("10.0", 3), # 1.0, 5.0, 10.0
("5.0", 2), # 1.0, 5.0
],
)
def test_less_than_or_equal_to(self, target, expected_count):
"""Test PlotLessThanOrEqualTo with numeric values."""
filter_obj = PlotLessThanOrEqualTo("request_rate", target)
result = filter_obj.apply(self.df_numeric)
assert len(result) == expected_count
@pytest.mark.parametrize(
"target,expected_count",
[
("10.0", 2), # 50.0, 100.0
("5.0", 3), # 10.0, 50.0, 100.0
],
)
def test_greater_than(self, target, expected_count):
"""Test PlotGreaterThan with numeric values."""
filter_obj = PlotGreaterThan("request_rate", target)
result = filter_obj.apply(self.df_numeric)
assert len(result) == expected_count
@pytest.mark.parametrize(
"target,expected_count",
[
("10.0", 3), # 10.0, 50.0, 100.0
("5.0", 4), # 5.0, 10.0, 50.0, 100.0
],
)
def test_greater_than_or_equal_to(self, target, expected_count):
"""Test PlotGreaterThanOrEqualTo with numeric values."""
filter_obj = PlotGreaterThanOrEqualTo("request_rate", target)
result = filter_obj.apply(self.df_numeric)
assert len(result) == expected_count
@pytest.mark.parametrize(
"filter_str,expected_var,expected_target,expected_type",
[
("request_rate==5.0", "request_rate", "5.0", PlotEqualTo),
("request_rate!=10.0", "request_rate", "10.0", PlotNotEqualTo),
("request_rate<50.0", "request_rate", "50.0", PlotLessThan),
("request_rate<=50.0", "request_rate", "50.0", PlotLessThanOrEqualTo),
("request_rate>10.0", "request_rate", "10.0", PlotGreaterThan),
("request_rate>=10.0", "request_rate", "10.0", PlotGreaterThanOrEqualTo),
("request_rate==inf", "request_rate", "inf", PlotEqualTo),
("request_rate!='inf'", "request_rate", "inf", PlotNotEqualTo),
],
)
def test_parse_str(self, filter_str, expected_var, expected_target, expected_type):
"""Test parsing filter strings."""
filter_obj = PlotFilterBase.parse_str(filter_str)
assert isinstance(filter_obj, expected_type)
assert filter_obj.var == expected_var
assert filter_obj.target == expected_target
def test_parse_str_inf_edge_case(self):
"""Test parsing 'inf' string in filter."""
filter_obj = PlotFilterBase.parse_str("request_rate==inf")
assert isinstance(filter_obj, PlotEqualTo)
assert filter_obj.var == "request_rate"
assert filter_obj.target == "inf"
def test_parse_multiple_filters(self):
"""Test parsing multiple filters."""
filters = PlotFilters.parse_str("request_rate>5.0,value<=40")
assert len(filters) == 2
assert isinstance(filters[0], PlotGreaterThan)
assert isinstance(filters[1], PlotLessThanOrEqualTo)
def test_parse_empty_filter(self):
"""Test parsing empty filter string."""
filters = PlotFilters.parse_str("")
assert len(filters) == 0
...@@ -298,10 +298,14 @@ def test_tp2_attn_quant_allreduce_rmsnorm( ...@@ -298,10 +298,14 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text, log_holder.text,
) )
assert len(log_matches) == 2, log_holder.text # 2 for each compile range
# (global compile range can be split due to fuse_allreduce_rmsnorm)
num_compile_ranges = len(compilation_config.get_compile_ranges())
assert num_compile_ranges in [1, 2]
assert int(log_matches[0]) == matches.attention_fusion assert len(log_matches) == 2 * num_compile_ranges, log_holder.text
assert int(log_matches[1]) == matches.attention_fusion
assert all(int(log_match) == matches.attention_fusion for log_match in log_matches)
log_matches = re.findall( log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns", r"collective_fusion.py:\d+] Replaced (\d+) patterns",
...@@ -312,6 +316,12 @@ def test_tp2_attn_quant_allreduce_rmsnorm( ...@@ -312,6 +316,12 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
assert int(log_matches[0]) == matches.allreduce_fusion assert int(log_matches[0]) == matches.allreduce_fusion
assert int(log_matches[1]) == matches.allreduce_fusion assert int(log_matches[1]) == matches.allreduce_fusion
log_matches = re.findall(
r"pass_manager.py:\d+] Skipping .*AllReduceFusionPass.* with compile range",
log_holder.text,
)
assert len(log_matches) == 2 * (num_compile_ranges - 1), log_holder.text
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -446,7 +456,6 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg ...@@ -446,7 +456,6 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
# No cudagraphs by default # No cudagraphs by default
if compilation_config.cudagraph_mode is None: if compilation_config.cudagraph_mode is None:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE compilation_config.cudagraph_mode = CUDAGraphMode.NONE
llm = LLM( llm = LLM(
model=model, model=model,
compilation_config=compilation_config, compilation_config=compilation_config,
...@@ -459,3 +468,9 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg ...@@ -459,3 +468,9 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Get the compile ranges split points after vllm config post init
# in order to compute compile ranges correctly
compilation_config.compile_ranges_split_points = (
llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import multiprocessing
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
...@@ -137,3 +139,67 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): ...@@ -137,3 +139,67 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
artifacts = compiled_mod.aot_compiled_fn._artifacts artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards() guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
@use_vllm_config(make_vllm_config())
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
"""
Test that compiling gpt2 twice results in a cache hit and
capture torch dynamic symbol creations to ensure make_symbol
not called on cache hit.
"""
import torch.fx.experimental.symbolic_shapes as symbolic_shapes_module
from torch.utils._sympy.symbol import make_symbol
from vllm import LLM
create_symbol_counter = multiprocessing.Value("i", 0)
original_make_symbol = make_symbol
@functools.wraps(original_make_symbol)
def counting_make_symbol(prefix, idx, **kwargs):
with create_symbol_counter.get_lock():
create_symbol_counter.value += 1
return original_make_symbol(prefix, idx, **kwargs)
symbolic_shapes_module.make_symbol = counting_make_symbol
try:
with monkeypatch.context() as m, tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
# First compilation - initialize model and generate
llm_model = LLM(
model="gpt2",
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
),
max_model_len=256,
)
llm_model.generate("Hello, my name is")
assert create_symbol_counter.value == 2
create_symbol_counter.value = 0
# Clean up first model
del llm_model
# Second compilation - should hit cache
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
llm_model = LLM(
model="gpt2",
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
),
max_model_len=256,
)
llm_model.generate("Hello, my name is")
assert create_symbol_counter.value == 0
finally:
# Restore original method
symbolic_shapes_module.make_symbol = original_make_symbol
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from torch import fx as fx
from torch import nn
# This import automatically registers `torch.ops.silly.attention`
import tests.compile.silly_attention # noqa
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.inductor_pass import (
InductorPass,
get_pass_context,
)
from vllm.config import (
VllmConfig,
set_current_vllm_config,
)
from vllm.config.compilation import CompilationConfig, CompilationMode
from vllm.config.scheduler import SchedulerConfig
from vllm.config.utils import Range
from vllm.forward_context import set_forward_context
BATCH_SIZE = 64
MLP_SIZE = 128
@support_torch_compile
class TestModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = x * 3
return x
@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]):
with set_forward_context({}, vllm_config=vllm_config):
model(torch.randn(BATCH_SIZE, MLP_SIZE))
for batch_size in batch_sizes:
model(torch.randn(batch_size, MLP_SIZE))
class PostGradRangeChecker(InductorPass):
def __init__(self, ranges: list[Range]):
self.ranges = ranges
self.num_calls = 0
def __call__(self, graph: fx.Graph):
compile_range = get_pass_context().compile_range
assert compile_range in self.ranges, (
f"Compile range {compile_range} not in {self.ranges}"
)
self.num_calls += 1
def uuid(self) -> str:
state: dict[str, Any] = {}
return InductorPass.hash_dict(state)
def test_compile_ranges(use_fresh_inductor_cache):
post_grad_range_checker = PostGradRangeChecker(
[
Range(start=1, end=8),
Range(start=16, end=16),
Range(start=9, end=32),
Range(start=64, end=64),
Range(start=33, end=8192),
]
)
torch.set_default_device("cuda")
vllm_config = VllmConfig(
scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8, 32],
compile_sizes=[16, 64, 128],
inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker,
},
),
)
with set_current_vllm_config(vllm_config):
model = TestModel(vllm_config=vllm_config, prefix="").eval()
# Number of compilations: 3 for each compile range + 2 compile sizes
batch_sizes = [1, 4, 16, 24, 48, 64, 8192]
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=1,
num_backend_compilations=5,
):
run_model(vllm_config, model, batch_sizes)
assert post_grad_range_checker.num_calls == 5
def test_compile_config_get_compile_ranges():
compilation_config = CompilationConfig(
compile_ranges_split_points=[8, 32],
)
VllmConfig(
scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
),
compilation_config=compilation_config,
)
assert compilation_config.get_compile_ranges() == [
Range(start=1, end=8),
Range(start=9, end=32),
Range(start=33, end=8192),
]
def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
# To force multiple compilations, we disable the compile cache
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
post_grad_range_checker = PostGradRangeChecker(
ranges=[
Range(start=1, end=8),
Range(start=9, end=8192),
]
)
scheduler_config = SchedulerConfig(
max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
)
torch.set_default_device("cuda")
def create_vllm_config():
return VllmConfig(
scheduler_config=scheduler_config,
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8],
inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker,
},
),
)
vllm_config_1 = create_vllm_config()
with set_current_vllm_config(vllm_config_1):
model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval()
batch_sizes = [1, 16]
run_model(vllm_config_1, model1, batch_sizes)
assert post_grad_range_checker.num_calls == 2
post_grad_range_checker.num_calls = 0
# Create a new vllm config with the new pass context
vllm_config_2 = create_vllm_config()
with set_current_vllm_config(vllm_config_2):
model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval()
batch_sizes = [4, 32]
run_model(vllm_config_2, model2, batch_sizes)
# Check that cache is used, so the number of calls
# should be 0
assert post_grad_range_checker.num_calls == 0
...@@ -10,7 +10,7 @@ from pydantic import ValidationError ...@@ -10,7 +10,7 @@ from pydantic import ValidationError
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.config.compilation import CompilationMode, PassConfig from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.logger import _print_warning_once from vllm.logger import _print_warning_once
...@@ -235,6 +235,70 @@ def test_splitting_ops_dynamic(): ...@@ -235,6 +235,70 @@ def test_splitting_ops_dynamic():
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
def test_moe_splitting_ops_deepep_ht_piecewise():
# Non-inductor, non-attn-fusion case: DeepEP HT with dp>1
# should add MoE ops to splitting_ops on top of attention ops.
config = VllmConfig(
parallel_config=ParallelConfig(
all2all_backend="deepep_high_throughput",
data_parallel_size=8,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
),
)
splitting_ops = config.compilation_config.splitting_ops
assert splitting_ops is not None
assert "vllm::moe_forward" in splitting_ops
assert "vllm::moe_forward_shared" in splitting_ops
def test_moe_splitting_ops_deepep_ht_inductor_partition():
# Inductor partition case: user-provided splitting_ops should be
# preserved and MoE ops should be appended for DeepEP HT with dp>1.
config = VllmConfig(
parallel_config=ParallelConfig(
all2all_backend="deepep_high_throughput",
data_parallel_size=8,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=[
"vllm::unified_attention",
"vllm::moe_forward",
"vllm::moe_forward_shared",
],
),
)
splitting_ops = config.compilation_config.splitting_ops
assert splitting_ops == [
"vllm::unified_attention",
"vllm::moe_forward",
"vllm::moe_forward_shared",
]
def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
# Pure attn-fusion case without inductor partition: even with
# DeepEP HT and dp>1, we should not re-enable piecewise compilation
# or add MoE ops into splitting_ops.
config = VllmConfig(
parallel_config=ParallelConfig(
all2all_backend="deepep_high_throughput",
data_parallel_size=8,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
),
)
assert config.compilation_config.splitting_ops == []
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
def test_should_split(): def test_should_split():
import torch import torch
...@@ -392,39 +456,48 @@ def test_pass_config_deprecation(caplog_vllm): ...@@ -392,39 +456,48 @@ def test_pass_config_deprecation(caplog_vllm):
assert "enable_fusion is deprecated" in caplog_vllm.text assert "enable_fusion is deprecated" in caplog_vllm.text
assert config.fuse_norm_quant is True assert config.fuse_norm_quant is True
assert config.fuse_act_quant is True assert config.fuse_act_quant is True
assert config.enable_fusion is None assert config.enable_fusion is True
# Test enable_attn_fusion -> fuse_attn_quant # Test enable_attn_fusion -> fuse_attn_quant
caplog_vllm.clear() caplog_vllm.clear()
config = PassConfig(enable_attn_fusion=True) config = PassConfig(enable_attn_fusion=True)
assert "enable_attn_fusion is deprecated" in caplog_vllm.text assert "enable_attn_fusion is deprecated" in caplog_vllm.text
assert config.fuse_attn_quant is True assert config.fuse_attn_quant is True
assert config.enable_attn_fusion is None assert config.enable_attn_fusion is True
# Test enable_noop -> eliminate_noops # Test enable_noop -> eliminate_noops
caplog_vllm.clear() caplog_vllm.clear()
config = PassConfig(enable_noop=True) config = PassConfig(enable_noop=True)
assert "enable_noop is deprecated" in caplog_vllm.text assert "enable_noop is deprecated" in caplog_vllm.text
assert config.eliminate_noops is True assert config.eliminate_noops is True
assert config.enable_noop is None assert config.enable_noop is True
# Test enable_sequence_parallelism -> enable_sp # Test enable_sequence_parallelism -> enable_sp
caplog_vllm.clear() caplog_vllm.clear()
config = PassConfig(enable_sequence_parallelism=True) config = PassConfig(enable_sequence_parallelism=True)
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
assert config.enable_sp is True assert config.enable_sp is True
assert config.enable_sequence_parallelism is None assert config.enable_sequence_parallelism is True
# Test enable_async_tp -> fuse_gemm_comms # Test enable_async_tp -> fuse_gemm_comms
caplog_vllm.clear() caplog_vllm.clear()
config = PassConfig(enable_async_tp=True) config = PassConfig(enable_async_tp=True)
assert "enable_async_tp is deprecated" in caplog_vllm.text assert "enable_async_tp is deprecated" in caplog_vllm.text
assert config.fuse_gemm_comms is True assert config.fuse_gemm_comms is True
assert config.enable_async_tp is None assert config.enable_async_tp is True
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms # Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
caplog_vllm.clear() caplog_vllm.clear()
config = PassConfig(enable_fi_allreduce_fusion=True) config = PassConfig(enable_fi_allreduce_fusion=True)
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
assert config.fuse_allreduce_rms is True assert config.fuse_allreduce_rms is True
assert config.enable_fi_allreduce_fusion is None assert config.enable_fi_allreduce_fusion is True
# Test hash consistency
config_old = PassConfig(enable_fusion=True)
config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True)
assert config_old.compute_hash() == config_new.compute_hash()
config_old = PassConfig(enable_async_tp=True)
config_new = PassConfig(fuse_gemm_comms=True)
assert config_old.compute_hash() == config_new.compute_hash()
...@@ -2,12 +2,21 @@ ...@@ -2,12 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc import gc
import tempfile
from contextlib import contextmanager
import pytest import pytest
import torch import torch
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationMode, DynamicShapesType from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.config.compilation import (
CompilationMode,
DynamicShapesConfig,
DynamicShapesType,
)
from vllm.forward_context import set_forward_context
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
...@@ -29,18 +38,19 @@ def get_test_models(): ...@@ -29,18 +38,19 @@ def get_test_models():
) )
@pytest.mark.parametrize("use_aot_compile", ["0"]) @pytest.mark.parametrize("use_aot_compile", ["0"])
@pytest.mark.parametrize("use_bytecode_hook", [True, False]) @pytest.mark.parametrize("use_bytecode_hook", [True, False])
@pytest.mark.parametrize("evaluate_guards", [False, True])
@pytest.mark.skipif( @pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
) )
def test_dynamic_shapes_compilation( def test_dynamic_shapes_compilation(
monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook monkeypatch,
model_name,
shapes_type,
use_aot_compile,
use_bytecode_hook,
evaluate_guards,
): ):
"""Test that all dynamic shapes types compile successfully""" """Test that all dynamic shapes types compile successfully"""
print(
f"\nTesting model: {model_name} with {shapes_type.name}, "
f"AOT compile: {use_aot_compile}, "
f"Bytecode hook: {use_bytecode_hook}"
)
if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED: if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0") pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
...@@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation( ...@@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation(
"mode": CompilationMode.VLLM_COMPILE, "mode": CompilationMode.VLLM_COMPILE,
"dynamic_shapes_config": { "dynamic_shapes_config": {
"type": shapes_type.value, "type": shapes_type.value,
"evaluate_guards": evaluate_guards,
}, },
}, },
) )
...@@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation( ...@@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation(
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.synchronize() torch.cuda.synchronize()
print("GPU memory cleared") print("GPU memory cleared")
@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
@pytest.mark.parametrize(
"dynamic_shapes_type",
[
DynamicShapesType.BACKED,
DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
],
)
@pytest.mark.parametrize("evaluate_guards", [False, True])
def test_model_specialization_with_evaluate_guards(
monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
):
"""Test that evaluate_guards correctly detects shape specialization
violations.
"""
if (
use_aot_compile == "1"
and dynamic_shapes_type == DynamicShapesType.BACKED
and evaluate_guards
):
pytest.skip("evaluate_guards for backed does not work with aot_compile =1")
@support_torch_compile
class ModelWithSizeCheck(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
def forward(self, x: torch.Tensor):
# This will cause specialization - torch.compile will guard on
# sx.shape[0]
if x.shape[0] >= 10:
return x * 10
else:
return x * 10
@support_torch_compile
class ModelWithOneSizeCheck(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
def forward(self, x: torch.Tensor):
# This will cause 0/1 specializations.
if x.shape[0] == 0:
return x * 10
if x.shape[0] == 1:
return x * 10
else:
return x * 10
@contextmanager
def use_vllm_config(vllm_config: VllmConfig):
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
yield
monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0")
# Create vllm config with the desired settings
from vllm.config import CompilationMode
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
dynamic_shapes_config=DynamicShapesConfig(
type=dynamic_shapes_type,
evaluate_guards=evaluate_guards,
),
)
)
def test(model_class, input1, input2, is_01_specialization=False):
with (
torch.no_grad(),
use_vllm_config(vllm_config),
tempfile.TemporaryDirectory() as tmpdirname,
):
monkeypatch.setenv("VLLM_CACHE_ROOT", tmpdirname)
model = model_class(vllm_config=vllm_config).cuda()
model(input1)
if evaluate_guards and (
not (
is_01_specialization
and dynamic_shapes_type == DynamicShapesType.BACKED
)
):
# This should fail because guards were added.
with pytest.raises(RuntimeError) as excinfo:
model(input2)
# Expected failure - guard was violated
error_msg = str(excinfo.value)
assert (
"GuardManager check failed" in error_msg
or "Detected recompile when torch.compile stance" in error_msg
), error_msg
else:
model(input2)
test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda())
test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda())
test(
ModelWithOneSizeCheck,
torch.randn(20, 10).cuda(),
torch.randn(1, 10).cuda(),
is_01_specialization=True,
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest import pytest
import torch import torch
import vllm.plugins import vllm.plugins
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.matcher_utils import QUANT_OPS
...@@ -18,6 +21,9 @@ from vllm.config import ( ...@@ -18,6 +21,9 @@ from vllm.config import (
VllmConfig, VllmConfig,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
QuantKey, QuantKey,
...@@ -25,10 +31,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -25,10 +31,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, Fp8LinearOp,
cutlass_block_fp8_supported,
cutlass_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity, maybe_create_device_identity,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from ..utils import override_cutlass_fp8_supported from ..utils import override_cutlass_fp8_supported
from .backend import TestBackend from .backend import TestBackend
...@@ -44,7 +52,7 @@ class TestModel(torch.nn.Module): ...@@ -44,7 +52,7 @@ class TestModel(torch.nn.Module):
self, self,
hidden_size: int, hidden_size: int,
eps: float, eps: float,
static: bool, group_shape: GroupShape,
cuda_force_torch: bool, cuda_force_torch: bool,
*args, *args,
**kwargs, **kwargs,
...@@ -52,8 +60,17 @@ class TestModel(torch.nn.Module): ...@@ -52,8 +60,17 @@ class TestModel(torch.nn.Module):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cuda_force_torch = cuda_force_torch self.cuda_force_torch = cuda_force_torch
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] if group_shape.is_per_group():
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN self.wscale = [
torch.rand(
(hidden_size // group_shape[1], hidden_size // group_shape[1]),
dtype=torch.float32,
)
for _ in range(3)
]
else:
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
static = group_shape == GroupShape.PER_TENSOR
quant_scale = ScaleDesc(torch.float32, static, group_shape) quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static: if static:
...@@ -61,18 +78,29 @@ class TestModel(torch.nn.Module): ...@@ -61,18 +78,29 @@ class TestModel(torch.nn.Module):
else: else:
self.scale = [None for _ in range(3)] self.scale = [None for _ in range(3)]
self.w = [ self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
for _ in range(3)
] ]
if not group_shape.is_per_group():
self.w = [self.w[0].t() for _ in range(3)]
with override_cutlass_fp8_supported(not cuda_force_torch): if group_shape.is_per_group():
self.fp8_linear = Fp8LinearOp( self.fp8_linear = W8A8BlockFp8LinearOp(
act_quant_static=static, weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape, act_quant_group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=False,
) )
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
else:
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
act_quant_static=static,
act_quant_group_shape=group_shape,
)
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
self.enable_rms_norm_custom_op = self.norm[0].enabled() self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() self.group_shape = group_shape
def forward(self, x): def forward(self, x):
# avoid having graph input be an arg to a pattern directly # avoid having graph input be an arg to a pattern directly
...@@ -119,13 +147,87 @@ class TestModel(torch.nn.Module): ...@@ -119,13 +147,87 @@ class TestModel(torch.nn.Module):
) )
GROUP_SHAPES = [
GroupShape.PER_TOKEN,
GroupShape.PER_TENSOR,
GroupShape(1, 128),
GroupShape(1, 64),
]
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, **kwargs):
super().__init__()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(3)
]
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = [
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
for _ in range(3)
]
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
self.eps = eps
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0])
# make sure resid is used for replacement to work
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
x2, resid, self.norm_weight[1], self.eps
)
x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1])
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
x3, resid, self.norm_weight[2], self.eps
)
x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2])
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
x4, resid, self.norm_weight[3], self.eps
)
return y4
def ops_in_model_before(self):
return [
torch.ops.vllm.rocm_aiter_rms_norm,
torch.ops.vllm.rocm_aiter_group_fp8_quant,
]
def ops_in_model_before_partial(self):
return []
def ops_in_model_after(self):
return [
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False]) @pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.parametrize(
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) "model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
list(itertools.product([TestModel], [True, False], [True, False]))
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
)
# cuda_force_torch used to test torch code path on platforms that # cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True. # cutlass_fp8_supported() == True.
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -139,16 +241,29 @@ def test_fusion_rmsnorm_quant( ...@@ -139,16 +241,29 @@ def test_fusion_rmsnorm_quant(
hidden_size, hidden_size,
num_tokens, num_tokens,
eps, eps,
static, group_shape,
model_class,
enable_rms_norm_custom_op, enable_rms_norm_custom_op,
enable_quant_fp8_custom_op, enable_quant_fp8_custom_op,
cuda_force_torch, cuda_force_torch,
): ):
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(1) torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
# Skip test for 64-bit group shape when running with cutlass or deepgemm
if group_shape == GroupShape(1, 64) and (
cutlass_block_fp8_supported() or is_deep_gemm_supported()
):
pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")
custom_ops = [] custom_ops = []
if enable_rms_norm_custom_op: if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm") custom_ops.append("+rms_norm")
...@@ -167,13 +282,24 @@ def test_fusion_rmsnorm_quant( ...@@ -167,13 +282,24 @@ def test_fusion_rmsnorm_quant(
with vllm.config.set_current_vllm_config(vllm_config): with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work # Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config) if model_class is TestRmsnormGroupFp8QuantModel:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
)
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
else:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass) backend2 = TestBackend(noop_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch) model = model_class(
hidden_size=hidden_size,
eps=eps,
group_shape=group_shape,
cuda_force_torch=cuda_force_torch,
)
# First dimension dynamic # First dimension dynamic
x = torch.rand(num_tokens, hidden_size) x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0) torch._dynamo.mark_dynamic(x, 0)
...@@ -202,7 +328,10 @@ def test_fusion_rmsnorm_quant( ...@@ -202,7 +328,10 @@ def test_fusion_rmsnorm_quant(
# there's a risk that the fused add doesn't get included in the # there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant. # replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add). # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op: if (
not enable_rms_norm_custom_op
and model_class is not TestRmsnormGroupFp8QuantModel
):
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7 assert n_add_nodes(backend.graph_pre_pass) == 7
......
...@@ -12,13 +12,13 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant ...@@ -12,13 +12,13 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import ( from vllm.config import (
AttentionConfig,
CacheConfig, CacheConfig,
CompilationConfig, CompilationConfig,
CompilationMode, CompilationMode,
...@@ -335,6 +335,7 @@ def test_attention_quant_pattern( ...@@ -335,6 +335,7 @@ def test_attention_quant_pattern(
custom_ops=custom_ops_list, custom_ops=custom_ops_list,
), ),
cache_config=CacheConfig(cache_dtype="fp8"), cache_config=CacheConfig(cache_dtype="fp8"),
attention_config=AttentionConfig(backend=backend),
) )
# Create test inputs # Create test inputs
...@@ -352,7 +353,6 @@ def test_attention_quant_pattern( ...@@ -352,7 +353,6 @@ def test_attention_quant_pattern(
with ( with (
set_current_vllm_config(vllm_config_unfused), set_current_vllm_config(vllm_config_unfused),
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused), set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
global_force_attn_backend_context_manager(backend),
): ):
model_unfused = model_class( model_unfused = model_class(
num_qo_heads=num_qo_heads, num_qo_heads=num_qo_heads,
...@@ -378,7 +378,6 @@ def test_attention_quant_pattern( ...@@ -378,7 +378,6 @@ def test_attention_quant_pattern(
with ( with (
set_current_vllm_config(vllm_config), set_current_vllm_config(vllm_config),
set_forward_context(attn_metadata=None, vllm_config=vllm_config), set_forward_context(attn_metadata=None, vllm_config=vllm_config),
global_force_attn_backend_context_manager(backend),
): ):
model_fused = model_class( model_fused = model_class(
num_qo_heads=num_qo_heads, num_qo_heads=num_qo_heads,
......
...@@ -5,9 +5,14 @@ import copy ...@@ -5,9 +5,14 @@ import copy
import pytest import pytest
import torch import torch
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.inductor_pass import (
CallableInductorPass,
InductorPass,
pass_context,
)
from vllm.compilation.pass_manager import PostGradPassManager from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.config.utils import Range
# dummy custom pass that doesn't inherit # dummy custom pass that doesn't inherit
...@@ -42,35 +47,37 @@ class ProperPass(InductorPass): ...@@ -42,35 +47,37 @@ class ProperPass(InductorPass):
], ],
) )
def test_pass_manager_uuid(callable): def test_pass_manager_uuid(callable):
# Some passes need dtype to be set # Set the pass context as PassManager uuid uses it
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16)) with pass_context(Range(start=1, end=8)):
# Some passes need dtype to be set
pass_manager = PostGradPassManager() config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
pass_manager.configure(config)
pass_manager = PostGradPassManager()
# Check that UUID is different if the same pass is added 2x pass_manager.configure(config)
pass_manager.add(callable)
uuid1 = pass_manager.uuid() # Check that UUID is different if the same pass is added 2x
pass_manager.add(callable) pass_manager.add(callable)
uuid2 = pass_manager.uuid() uuid1 = pass_manager.uuid()
assert uuid1 != uuid2 pass_manager.add(callable)
uuid2 = pass_manager.uuid()
# UUID should be the same as the original one, assert uuid1 != uuid2
# as we constructed in the same way.
pass_manager2 = PostGradPassManager() # UUID should be the same as the original one,
pass_manager2.configure(config) # as we constructed in the same way.
pass_manager2.add(callable) pass_manager2 = PostGradPassManager()
assert uuid1 == pass_manager2.uuid() pass_manager2.configure(config)
pass_manager2.add(callable)
# UUID should be different due to config change assert uuid1 == pass_manager2.uuid()
config2 = copy.deepcopy(config)
config2.compilation_config.pass_config.fuse_norm_quant = ( # UUID should be different due to config change
not config2.compilation_config.pass_config.fuse_norm_quant config2 = copy.deepcopy(config)
) config2.compilation_config.pass_config.fuse_norm_quant = (
config2.compilation_config.pass_config.fuse_act_quant = ( not config2.compilation_config.pass_config.fuse_norm_quant
not config2.compilation_config.pass_config.fuse_act_quant )
) config2.compilation_config.pass_config.fuse_act_quant = (
pass_manager3 = PostGradPassManager() not config2.compilation_config.pass_config.fuse_act_quant
pass_manager3.configure(config2) )
pass_manager3.add(callable) pass_manager3 = PostGradPassManager()
assert uuid1 != pass_manager3.uuid() pass_manager3.configure(config2)
pass_manager3.add(callable)
assert uuid1 != pass_manager3.uuid()
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
from vllm._aiter_ops import IS_AITER_FOUND
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.activation_quant_fusion import ( from vllm.compilation.activation_quant_fusion import (
FUSED_OPS, FUSED_OPS,
...@@ -24,6 +25,7 @@ from vllm.config import ( ...@@ -24,6 +25,7 @@ from vllm.config import (
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
kFp8StaticTensorSym, kFp8StaticTensorSym,
...@@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module): ...@@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
return [FUSED_OPS[kNvfp4Quant]] return [FUSED_OPS[kNvfp4Quant]]
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = torch.rand(
(scale_hidden_size, scale_hidden_size), dtype=torch.float32
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
return x2
def ops_in_model_before(self):
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
]
def ops_in_model_after(self):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
@pytest.mark.parametrize("num_tokens", [32, 64]) @pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
...@@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module): ...@@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_class, enable_quant_fp8_custom_op, cuda_force_torch", "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False])) list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+ [(TestSiluMulNvfp4QuantModel, False, False)], + [
(TestSiluMulNvfp4QuantModel, False, False),
(TestSiluMulGroupFp8QuantModel, False, False),
],
) )
# cuda_force_torch used to test torch code path on platforms that # cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True. # cutlass_fp8_supported() == True.
...@@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant( ...@@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
num_tokens: int, num_tokens: int,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, dtype: torch.dtype,
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel], model_class: type[
TestSiluMulFp8QuantModel
| TestSiluMulNvfp4QuantModel
| TestSiluMulGroupFp8QuantModel
],
enable_silu_mul_custom_op: bool, enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool, enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool, cuda_force_torch: bool,
): ):
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported(): if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.") pytest.skip("NVFP4 is not supported on this GPU.")
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
...@@ -173,9 +217,15 @@ def test_fusion_silu_and_mul_quant( ...@@ -173,9 +217,15 @@ def test_fusion_silu_and_mul_quant(
) )
with set_current_vllm_config(config): with set_current_vllm_config(config):
fusion_pass = ActivationQuantFusionPass(config) fusion_passes = [ActivationQuantFusionPass(config)]
if IS_AITER_FOUND:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)] passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes) backend = TestBackend(*passes)
model = model_class( model = model_class(
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
...@@ -194,12 +244,14 @@ def test_fusion_silu_and_mul_quant( ...@@ -194,12 +244,14 @@ def test_fusion_silu_and_mul_quant(
atol, rtol = 1e-3, 1e-3 atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel: elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1 atol, rtol = 1e-1, 1e-1
elif model_class == TestSiluMulGroupFp8QuantModel:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close( torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
) )
assert fusion_pass.matched_count == 1 assert sum([p.matched_count for p in fusion_passes]) == 1
# In pre-nodes, quant op should be present and fused kernels should not # In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before()) backend.check_before_ops(model.ops_in_model_before())
......
...@@ -27,7 +27,7 @@ import threading ...@@ -27,7 +27,7 @@ import threading
from collections.abc import Generator from collections.abc import Generator
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
from typing import Any, Callable, TypedDict, TypeVar, cast from typing import Any, Callable, TypedDict, TypeVar, cast, TYPE_CHECKING
import numpy as np import numpy as np
import pytest import pytest
...@@ -59,6 +59,7 @@ from vllm.distributed import ( ...@@ -59,6 +59,7 @@ from vllm.distributed import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams from vllm.sampling_params import BeamSearchParams
...@@ -66,6 +67,14 @@ from vllm.transformers_utils.utils import maybe_model_redirect ...@@ -66,6 +67,14 @@ from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.utils.torch_utils import set_default_torch_num_threads
from torch._inductor.utils import fresh_cache
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.generation.utils import GenerateOutput
logger = init_logger(__name__) logger = init_logger(__name__)
_TEST_DIR = os.path.dirname(__file__) _TEST_DIR = os.path.dirname(__file__)
...@@ -201,10 +210,7 @@ def dynamo_reset(): ...@@ -201,10 +210,7 @@ def dynamo_reset():
@pytest.fixture @pytest.fixture
def example_prompts() -> list[str]: def example_prompts() -> list[str]:
prompts = [] return [prompt for filename in _TEST_PROMPTS for prompt in _read_prompts(filename)]
for filename in _TEST_PROMPTS:
prompts += _read_prompts(filename)
return prompts
@pytest.fixture @pytest.fixture
...@@ -223,10 +229,7 @@ class DecoderPromptType(Enum): ...@@ -223,10 +229,7 @@ class DecoderPromptType(Enum):
@pytest.fixture @pytest.fixture
def example_long_prompts() -> list[str]: def example_long_prompts() -> list[str]:
prompts = [] return [prompt for filename in _LONG_PROMPTS for prompt in _read_prompts(filename)]
for filename in _LONG_PROMPTS:
prompts += _read_prompts(filename)
return prompts
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
...@@ -352,10 +355,13 @@ class HfRunner: ...@@ -352,10 +355,13 @@ class HfRunner:
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else: else:
model = auto_cls.from_pretrained( model = cast(
model_name, nn.Module,
trust_remote_code=trust_remote_code, auto_cls.from_pretrained(
**model_kwargs, model_name,
trust_remote_code=trust_remote_code,
**model_kwargs,
),
) )
# in case some unquantized custom models are not in same dtype # in case some unquantized custom models are not in same dtype
...@@ -373,10 +379,12 @@ class HfRunner: ...@@ -373,10 +379,12 @@ class HfRunner:
self.model = model self.model = model
if not skip_tokenizer_init: if not skip_tokenizer_init:
self.tokenizer = AutoTokenizer.from_pretrained( self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = (
model_name, AutoTokenizer.from_pretrained(
dtype=dtype, model_name,
trust_remote_code=trust_remote_code, dtype=dtype,
trust_remote_code=trust_remote_code,
)
) )
# don't put this import at the top level # don't put this import at the top level
...@@ -397,6 +405,7 @@ class HfRunner: ...@@ -397,6 +405,7 @@ class HfRunner:
images: PromptImageInput | None = None, images: PromptImageInput | None = None,
videos: PromptVideoInput | None = None, videos: PromptVideoInput | None = None,
audios: PromptAudioInput | None = None, audios: PromptAudioInput | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]]: ) -> list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]]:
if images is not None: if images is not None:
assert len(prompts) == len(images) assert len(prompts) == len(images)
...@@ -410,10 +419,18 @@ class HfRunner: ...@@ -410,10 +419,18 @@ class HfRunner:
all_inputs: list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]] = [] all_inputs: list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]] = []
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
if isinstance(prompt, str): if isinstance(prompt, str):
processor_kwargs: dict[str, Any] = { # Create a copy to avoid modifying the original dict
"text": prompt, processor_kwargs = (
"return_tensors": "pt", tokenization_kwargs.copy()
} if tokenization_kwargs is not None
else {}
)
processor_kwargs.update(
{
"text": prompt,
"return_tensors": "pt",
}
)
if images is not None and (image := images[i]) is not None: if images is not None and (image := images[i]) is not None:
processor_kwargs["images"] = image processor_kwargs["images"] = image
if videos is not None and (video := videos[i]) is not None: if videos is not None and (video := videos[i]) is not None:
...@@ -494,7 +511,7 @@ class HfRunner: ...@@ -494,7 +511,7 @@ class HfRunner:
outputs: list[tuple[list[list[int]], list[str]]] = [] outputs: list[tuple[list[list[int]], list[str]]] = []
for inputs in all_inputs: for inputs in all_inputs:
output_ids = self.model.generate( output_ids: torch.Tensor = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
use_cache=True, use_cache=True,
**kwargs, **kwargs,
...@@ -504,8 +521,7 @@ class HfRunner: ...@@ -504,8 +521,7 @@ class HfRunner:
skip_special_tokens=True, skip_special_tokens=True,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
) )
output_ids = output_ids.cpu().tolist() outputs.append((output_ids.cpu().tolist(), output_str))
outputs.append((output_ids, output_str))
return outputs return outputs
def generate_greedy( def generate_greedy(
...@@ -573,7 +589,7 @@ class HfRunner: ...@@ -573,7 +589,7 @@ class HfRunner:
all_logprobs: list[list[torch.Tensor]] = [] all_logprobs: list[list[torch.Tensor]] = []
for inputs in all_inputs: for inputs in all_inputs:
output = self.model.generate( output: "GenerateOutput" = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
use_cache=True, use_cache=True,
do_sample=False, do_sample=False,
...@@ -655,7 +671,7 @@ class HfRunner: ...@@ -655,7 +671,7 @@ class HfRunner:
all_output_strs: list[str] = [] all_output_strs: list[str] = []
for inputs in all_inputs: for inputs in all_inputs:
output = self.model.generate( output: "GenerateOutput" = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
use_cache=True, use_cache=True,
do_sample=False, do_sample=False,
...@@ -1389,7 +1405,11 @@ class LocalAssetServer: ...@@ -1389,7 +1405,11 @@ class LocalAssetServer:
return f"{self.base_url}/{name}" return f"{self.base_url}/{name}"
def get_image_asset(self, name: str) -> Image.Image: def get_image_asset(self, name: str) -> Image.Image:
return fetch_image(self.url_for(name)) image = fetch_image(self.url_for(name))
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
return image
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
...@@ -1457,3 +1477,14 @@ def clean_gpu_memory_between_tests(): ...@@ -1457,3 +1477,14 @@ def clean_gpu_memory_between_tests():
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
@pytest.fixture
def use_fresh_inductor_cache():
"""
Use a fresh inductor cache for the test.
This is useful to ensure that the test is not affected by the
previous test calls.
"""
with fresh_cache():
yield
...@@ -16,16 +16,35 @@ from typing import Literal, NamedTuple ...@@ -16,16 +16,35 @@ from typing import Literal, NamedTuple
import pytest import pytest
import torch import torch
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
from tests.utils import RemoteOpenAIServer, create_new_process_for_each_test
from vllm.config.model import RunnerOption from vllm.config.model import RunnerOption
from vllm.logger import init_logger from vllm.logger import init_logger
from ..models.registry import HF_EXAMPLE_MODELS from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test
logger = init_logger("test_context_parallel") logger = init_logger("test_context_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
CP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"deepseek-ai/DeepSeek-V2-Lite-Chat",
"Qwen/Qwen2.5-1.5B-Instruct",
]
# GSM8K eval configuration
NUM_QUESTIONS = 256 # Fast eval for CI
NUM_SHOTS = 5 # Few-shot examples
# tp accuracy with 2% buffer
MIN_ACCURACY = {
# .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.64,
# .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
"Qwen/Qwen2.5-1.5B-Instruct": 0.52,
}
class ParallelSetup(NamedTuple): class ParallelSetup(NamedTuple):
tp_size: int tp_size: int
...@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple): ...@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):
class CPTestOptions(NamedTuple): class CPTestOptions(NamedTuple):
multi_node_only: bool multi_node_only: bool
load_format: str | None = None
attn_backend: str | None = None attn_backend: str | None = None
...@@ -54,17 +72,20 @@ class CPTestSettings: ...@@ -54,17 +72,20 @@ class CPTestSettings:
*, *,
tp_base: int = 4, tp_base: int = 4,
pp_base: int = 1, pp_base: int = 1,
dcp_base: int = 1, dcp_multipliers: list[float] | None = None,
cp_kv_cache_interleave_size: int = 1, cp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False, multi_node_only: bool = False,
runner: RunnerOption = "auto", runner: RunnerOption = "auto",
load_format: str | None = None,
attn_backend: str | None = None, attn_backend: str | None = None,
): ):
parallel_setups = [] parallel_setups = []
if dcp_multipliers is None:
dcp_multipliers = [
0.5,
]
for eager_mode_val in [False]: for eager_mode_val in [False]:
for pp_multiplier in [1]: for pp_multiplier in [1]:
for dcp_multiplier in [0.5, 1]: for dcp_multiplier in dcp_multipliers:
for chunked_prefill_val in [True]: for chunked_prefill_val in [True]:
parallel_setups.append( parallel_setups.append(
ParallelSetup( ParallelSetup(
...@@ -82,7 +103,6 @@ class CPTestSettings: ...@@ -82,7 +103,6 @@ class CPTestSettings:
runner=runner, runner=runner,
test_options=CPTestOptions( test_options=CPTestOptions(
multi_node_only=multi_node_only, multi_node_only=multi_node_only,
load_format=load_format,
attn_backend=attn_backend, attn_backend=attn_backend,
), ),
) )
...@@ -101,7 +121,27 @@ class CPTestSettings: ...@@ -101,7 +121,27 @@ class CPTestSettings:
) )
def _compare_cp_with_tp( CP_TEXT_GENERATION_MODELS = {
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(dcp_multipliers=[1]),
CPTestSettings.detailed(
dcp_multipliers=[0.5],
cp_kv_cache_interleave_size=64,
attn_backend="FLASHMLA",
),
],
"Qwen/Qwen2.5-1.5B-Instruct": [
CPTestSettings.detailed(
cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
),
CPTestSettings.detailed(
cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
),
],
}
def _test_cp_gsm8k(
model_id: str, model_id: str,
parallel_setup: ParallelSetup, parallel_setup: ParallelSetup,
distributed_backend: str, distributed_backend: str,
...@@ -121,7 +161,7 @@ def _compare_cp_with_tp( ...@@ -121,7 +161,7 @@ def _compare_cp_with_tp(
chunked_prefill, chunked_prefill,
) = parallel_setup ) = parallel_setup
multi_node_only, load_format, attn_backend = test_options multi_node_only, attn_backend = test_options
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip") model_info.check_transformers_version(on_fail="skip")
...@@ -130,22 +170,7 @@ def _compare_cp_with_tp( ...@@ -130,22 +170,7 @@ def _compare_cp_with_tp(
tokenizer_mode = model_info.tokenizer_mode tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides hf_overrides = model_info.hf_overrides
if load_format == "dummy": model_info.check_available_online(on_fail="skip")
# Avoid OOM
text_overrides = {
"num_hidden_layers": 4,
"hidden_size": 512,
"intermediate_size": 800,
"num_attention_heads": 4,
"num_key_value_heads": 1,
}
if is_multimodal:
hf_overrides.update({"text_config": text_overrides})
else:
hf_overrides.update(text_overrides)
else:
model_info.check_available_online(on_fail="skip")
if num_gpus_available < tp_size * pp_size: if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
...@@ -157,90 +182,70 @@ def _compare_cp_with_tp( ...@@ -157,90 +182,70 @@ def _compare_cp_with_tp(
if multi_node_only and not VLLM_MULTI_NODE: if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting") pytest.skip("Not in multi-node setting")
common_args = [ server_args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
"bfloat16", "bfloat16",
"--max-model-len", "--max-model-len",
"2048", "4096",
"--max-num-seqs", "--max-num-seqs",
"8", "64",
] ]
if chunked_prefill: if chunked_prefill:
common_args.append("--enable-chunked-prefill") server_args.append("--enable-chunked-prefill")
if eager_mode: if eager_mode:
common_args.append("--enforce-eager") server_args.append("--enforce-eager")
if runner != "auto": if runner != "auto":
common_args.extend(["--runner", runner]) server_args.extend(["--runner", runner])
if trust_remote_code: if trust_remote_code:
common_args.append("--trust-remote-code") server_args.append("--trust-remote-code")
if tokenizer_mode: if tokenizer_mode:
common_args.extend(["--tokenizer-mode", tokenizer_mode]) server_args.extend(["--tokenizer-mode", tokenizer_mode])
if load_format:
common_args.extend(["--load-format", load_format])
if hf_overrides: if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) server_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if not attn_backend: server_args.extend(
cp_env = tp_env = {} [
else: "--tensor-parallel-size",
cp_env = tp_env = { str(tp_size),
"VLLM_ATTENTION_BACKEND": attn_backend, "--pipeline-parallel-size",
} str(pp_size),
"--decode-context-parallel-size",
cp_args = [ str(dcp_size),
*common_args, "--dcp-kv-cache-interleave-size",
"--tensor-parallel-size", str(cp_kv_cache_interleave_size),
str(tp_size), "--distributed-executor-backend",
"--pipeline-parallel-size", distributed_backend,
str(pp_size), ]
"--decode-context-parallel-size", )
str(dcp_size),
"--dcp-kv-cache-interleave-size",
str(cp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
tp_args = [ server_env = {}
*common_args, if attn_backend:
"--tensor-parallel-size", server_env["VLLM_ATTENTION_BACKEND"] = attn_backend
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
"--distributed-executor-backend",
distributed_backend,
]
compare_two_settings( with RemoteOpenAIServer(
model_id, model_id,
cp_args, server_args,
tp_args, env_dict=server_env,
cp_env,
tp_env,
method=method,
max_wait_seconds=720, max_wait_seconds=720,
) ) as remote_server:
host = f"http://{remote_server.host}"
port = remote_server.port
CP_TEXT_GENERATION_MODELS = {
"deepseek-ai/DeepSeek-V2-Lite-Chat": [ # Run GSM8K evaluation
CPTestSettings.detailed(), results = evaluate_gsm8k(
CPTestSettings.detailed(tp_base=2), num_questions=NUM_QUESTIONS,
CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64), num_shots=NUM_SHOTS,
], host=host,
"bigcode/gpt_bigcode-santacoder": [ port=port,
CPTestSettings.detailed(), )
CPTestSettings.detailed(tp_base=2),
],
}
CP_TEST_MODELS = [ # Validate accuracy is reasonable
# TODO support other models accuracy = results["accuracy"]
# [LANGUAGE GENERATION] min_accuracy = MIN_ACCURACY[model_id]
"deepseek-ai/DeepSeek-V2-Lite-Chat", assert accuracy >= min_accuracy, (
"bigcode/gpt_bigcode-santacoder", f"TP+DCP accuracy too low: {accuracy:.3f} < {min_accuracy:.3f}"
] )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -274,12 +279,12 @@ def test_cp_generation( ...@@ -274,12 +279,12 @@ def test_cp_generation(
): ):
pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher") pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
if ( if (
model_id == "bigcode/gpt_bigcode-santacoder" model_id == "Qwen/Qwen2.5-1.5B-Instruct"
and torch.cuda.get_device_capability() != (9, 0) and torch.cuda.get_device_capability() != (9, 0)
): ):
pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0") pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")
_compare_cp_with_tp( _test_cp_gsm8k(
model_id, model_id,
parallel_setup, parallel_setup,
distributed_backend, distributed_backend,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import pytest import pytest
import torch import torch
from vllm.distributed.eplb.rebalance_algo import rebalance_experts from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
def test_basic_rebalance(): def test_basic_rebalance():
...@@ -23,7 +23,7 @@ def test_basic_rebalance(): ...@@ -23,7 +23,7 @@ def test_basic_rebalance():
num_nodes = 2 num_nodes = 2
num_gpus = 8 num_gpus = 8
phy2log, log2phy, logcnt = rebalance_experts( phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
...@@ -77,7 +77,7 @@ def test_single_gpu_case(): ...@@ -77,7 +77,7 @@ def test_single_gpu_case():
num_nodes = 1 num_nodes = 1
num_gpus = 1 num_gpus = 1
phy2log, log2phy, logcnt = rebalance_experts( phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
...@@ -99,7 +99,7 @@ def test_equal_weights(): ...@@ -99,7 +99,7 @@ def test_equal_weights():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts( phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
...@@ -122,7 +122,7 @@ def test_extreme_weight_imbalance(): ...@@ -122,7 +122,7 @@ def test_extreme_weight_imbalance():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts( phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
...@@ -150,7 +150,7 @@ def test_multiple_layers(): ...@@ -150,7 +150,7 @@ def test_multiple_layers():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts( phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
...@@ -175,14 +175,14 @@ def test_parameter_validation(): ...@@ -175,14 +175,14 @@ def test_parameter_validation():
# Test non-divisible case - this should handle normally without throwing # Test non-divisible case - this should handle normally without throwing
# errors because the function will fall back to global load balancing # errors because the function will fall back to global load balancing
# strategy # strategy
phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4) phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
assert phy2log.shape == (1, 8) assert phy2log.shape == (1, 8)
assert logcnt.shape == (1, 4) assert logcnt.shape == (1, 4)
# Test cases that will actually cause errors: # Test cases that will actually cause errors:
# num_physical_experts not divisible by num_gpus # num_physical_experts not divisible by num_gpus
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4 DefaultEplbPolicy.rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4
def test_small_scale_hierarchical(): def test_small_scale_hierarchical():
...@@ -197,7 +197,7 @@ def test_small_scale_hierarchical(): ...@@ -197,7 +197,7 @@ def test_small_scale_hierarchical():
num_nodes = 2 # 2 nodes num_nodes = 2 # 2 nodes
num_gpus = 4 # 4 GPUs num_gpus = 4 # 4 GPUs
phy2log, log2phy, logcnt = rebalance_experts( phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
...@@ -224,7 +224,7 @@ def test_global_load_balance_fallback(): ...@@ -224,7 +224,7 @@ def test_global_load_balance_fallback():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts( phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
...@@ -246,7 +246,7 @@ def test_device_compatibility(device): ...@@ -246,7 +246,7 @@ def test_device_compatibility(device):
num_nodes = 1 num_nodes = 1
num_gpus = 2 num_gpus = 2
phy2log, log2phy, logcnt = rebalance_experts( phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
...@@ -263,7 +263,9 @@ def test_additional_cases(): ...@@ -263,7 +263,9 @@ def test_additional_cases():
weight1 = torch.tensor( weight1 = torch.tensor(
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]] [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
) )
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8) phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
weight1, 24, 8, 4, 8
)
assert phy2log1.shape == (1, 24) assert phy2log1.shape == (1, 24)
assert logcnt1.shape == (1, 16) assert logcnt1.shape == (1, 16)
...@@ -276,7 +278,9 @@ def test_additional_cases(): ...@@ -276,7 +278,9 @@ def test_additional_cases():
[12, 25, 50, 100, 150, 200], # Increasing weights [12, 25, 50, 100, 150, 200], # Increasing weights
] ]
) )
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2) phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
weight2, 10, 3, 1, 2
)
assert phy2log2.shape == (2, 10) assert phy2log2.shape == (2, 10)
assert logcnt2.shape == (2, 6) assert logcnt2.shape == (2, 6)
...@@ -300,7 +304,7 @@ if __name__ == "__main__": ...@@ -300,7 +304,7 @@ if __name__ == "__main__":
num_nodes = 2 num_nodes = 2
num_gpus = 8 num_gpus = 8
phy2log, log2phy, logcnt = rebalance_experts( phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
print(phy2log) print(phy2log)
......
...@@ -6,6 +6,7 @@ import lm_eval ...@@ -6,6 +6,7 @@ import lm_eval
import pytest import pytest
from tests.utils import large_gpu_mark from tests.utils import large_gpu_mark
from vllm.platforms import current_platform
def get_model_args( def get_model_args(
...@@ -45,6 +46,12 @@ def get_model_args( ...@@ -45,6 +46,12 @@ def get_model_args(
return model_args return model_args
pytestmark = pytest.mark.skipif(
current_platform.is_rocm(),
reason="EPLB with Spec Decode is a work in progress on ROCm.",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_setup", "model_setup",
[ [
......
...@@ -61,7 +61,7 @@ def test_get_kv_connector_cache_layout_with_multi_connector(): ...@@ -61,7 +61,7 @@ def test_get_kv_connector_cache_layout_with_multi_connector():
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={ kv_connector_extra_config={
"connectors": [ "connectors": [
{"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"}, {"kv_connector": "ExampleConnector", "kv_role": "kv_both"},
{"kv_connector": "NixlConnector", "kv_role": "kv_both"}, {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
] ]
}, },
......
...@@ -109,7 +109,7 @@ TEXT_GENERATION_MODELS = { ...@@ -109,7 +109,7 @@ TEXT_GENERATION_MODELS = {
"baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(), "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(),
"bigscience/bloomz-1b1": PPTestSettings.fast(), "bigscience/bloomz-1b1": PPTestSettings.fast(),
"zai-org/chatglm3-6b": PPTestSettings.fast(), "zai-org/chatglm3-6b": PPTestSettings.fast(),
"CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"), "CohereLabs/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"),
"databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"), "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
"Deci/DeciLM-7B-instruct": PPTestSettings.fast(), "Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
"deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(), "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
......
...@@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int): ...@@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int):
modality=modality, modality=modality,
key=key, key=key,
data=torch.empty((size,), dtype=torch.int8), data=torch.empty((size,), dtype=torch.int8),
field=MultiModalSharedField(1), field=MultiModalSharedField(batch_size=1),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment