Commit a810671a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import ( from vllm.config import (
AttentionConfig,
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
KVTransferConfig, KVTransferConfig,
...@@ -94,6 +95,7 @@ def create_vllm_config( ...@@ -94,6 +95,7 @@ def create_vllm_config(
dtype: str = "float16", dtype: str = "float16",
cache_dtype: str = "auto", cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None, hf_overrides: dict[str, Any] | None = None,
attention_backend: str | None = None,
) -> VllmConfig: ) -> VllmConfig:
"""Initialize VllmConfig For Testing.""" """Initialize VllmConfig For Testing."""
model_config = ModelConfig( model_config = ModelConfig(
...@@ -131,12 +133,14 @@ def create_vllm_config( ...@@ -131,12 +133,14 @@ def create_vllm_config(
enable_permute_local_kv=enable_permute_local_kv, enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {}, kv_connector_extra_config=kv_connector_extra_config or {},
) )
attention_config = AttentionConfig(backend=attention_backend)
return VllmConfig( return VllmConfig(
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
kv_transfer_config=kv_transfer_config, kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"), device_config=DeviceConfig("cpu"),
attention_config=attention_config,
) )
......
...@@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt ...@@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
CPU_BLOCK_SIZES = [48] CPU_BLOCK_SIZES = [48]
ATTN_BACKENDS = ["FLASH_ATTN"] ATTN_BACKENDS = ["FLASH_ATTN"]
...@@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None: ...@@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
topic="test", topic="test",
) )
with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend): llm = LLM(
llm = LLM( model="meta-llama/Llama-3.2-1B-Instruct",
model="meta-llama/Llama-3.2-1B-Instruct", gpu_memory_utilization=0.5,
gpu_memory_utilization=0.5, kv_events_config=kv_events_config,
kv_events_config=kv_events_config, kv_transfer_config=kv_transfer_config,
kv_transfer_config=kv_transfer_config, attention_config={"backend": attn_backend},
) )
events_endpoint = events_endpoint.replace("*", "127.0.0.1") events_endpoint = events_endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic) subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the analytic estimators in metrics/flops.py.
"""
import types
from types import SimpleNamespace
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
from transformers.models.llama4.configuration_llama4 import (
Llama4Config,
Llama4TextConfig,
)
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
from vllm.config.model import ModelConfig, get_hf_text_config
from vllm.v1.metrics.perf import (
AttentionMetrics,
BaseConfigParser,
ExecutionContext,
FfnMetrics,
ModelMetrics,
ParsedArgs,
UnembedMetrics,
)
class MockModelConfig:
"""Mock ModelConfig that implements the getter methods used by parsers."""
def __init__(self, hf_config, dtype):
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(hf_config)
self.dtype = dtype
self.is_attention_free = False
def __getattr__(self, name):
# 1. Check if ModelConfig actually has this attribute
if not hasattr(ModelConfig, name):
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}' "
f"and neither does 'ModelConfig'."
)
# 2. Fetch the attribute from the ModelConfig CLASS
attr = getattr(ModelConfig, name)
# 3. Case A: It is a @property
if isinstance(attr, property):
# Manually invoke the property's getter, passing 'self' (this mock instance)
return attr.__get__(self, self.__class__)
# 4. Case B: It is a standard method (function)
if isinstance(attr, types.FunctionType):
# Bind the function to 'self' so it acts like a method of
# this instance. This creates a bound method where 'self' is
# automatically passed as the first arg.
return types.MethodType(attr, self)
# 5. Case C: It is a class attribute / static variable
return attr
def create_mock_vllm_config(
hf_config,
model_dtype="bfloat16",
cache_dtype="auto",
quant_config=None,
data_parallel_size=1,
tensor_parallel_size=1,
pipeline_parallel_size=1,
enable_expert_parallel=False,
) -> SimpleNamespace:
vllm_config = SimpleNamespace()
vllm_config.model_config = MockModelConfig(hf_config, model_dtype)
vllm_config.cache_config = SimpleNamespace()
vllm_config.cache_config.cache_dtype = cache_dtype
vllm_config.quant_config = quant_config
vllm_config.parallel_config = SimpleNamespace()
vllm_config.parallel_config.data_parallel_size = data_parallel_size
vllm_config.parallel_config.tensor_parallel_size = tensor_parallel_size
vllm_config.parallel_config.pipeline_parallel_size = pipeline_parallel_size
vllm_config.parallel_config.enable_expert_parallel = enable_expert_parallel
return vllm_config
#### Parser Tests ####
def test_base_config_parser():
"""Test BaseConfigParser extracts base model attributes correctly."""
hf_config = Qwen3Config(
vocab_size=50000,
hidden_size=2048,
num_attention_heads=16,
num_hidden_layers=24,
)
vllm_config = create_mock_vllm_config(hf_config, model_dtype="float16")
parser = BaseConfigParser()
args = ParsedArgs()
result = parser.parse(args, vllm_config)
assert result.vocab_size == 50000
assert result.hidden_size == 2048
assert result.num_attention_heads == 16
assert result.num_hidden_layers == 24
assert result.weight_byte_size == 2 # float16 is 2 bytes
assert result.activation_byte_size == 2 # default activation size
def test_base_attention_config_parser_with_gqa():
"""Test BaseAttentionConfigParser with grouped query attention."""
hf_config = Qwen3Config(
hidden_size=4096,
num_attention_heads=32,
num_key_value_heads=8, # GQA with 4:1 ratio
head_dim=128,
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = AttentionMetrics.get_parser()
result = parser_chain.parse(vllm_config)
assert result.num_key_value_heads == 8
assert result.head_dim == 128
def test_base_attention_config_parser_without_gqa():
"""
Test BaseAttentionConfigParser defaults to MHA when num_key_value_heads not
specified.
"""
hf_config = Qwen3Config(
hidden_size=4096,
num_attention_heads=32,
# No num_key_value_heads specified
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = AttentionMetrics.get_parser()
result = parser_chain.parse(vllm_config)
# Should default to MHA (num_key_value_heads = num_attention_heads)
assert result.num_key_value_heads == 32
def test_base_ffn_config_parser_dense():
"""Test BaseFfnConfigParser for dense FFN."""
hf_config = Qwen3Config(
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = FfnMetrics.get_parser()
result = parser_chain.parse(vllm_config)
assert result.intermediate_size == 11008
assert result.num_experts == 0
assert result.num_experts_per_tok == 0
assert result.num_moe_layers == 0 # No MoE
def test_base_ffn_config_parser_moe():
"""Test BaseFfnConfigParser for MoE FFN."""
hf_config = Qwen3MoeConfig(
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_experts=64,
num_experts_per_tok=8,
moe_intermediate_size=14336,
n_shared_experts=2,
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = FfnMetrics.get_parser()
result = parser_chain.parse(vllm_config)
assert result.num_experts == 64
assert result.num_experts_per_tok == 8
assert result.moe_intermediate_size == 14336
assert result.num_shared_experts == 2
assert result.num_moe_layers == 32 # All layers are MoE by default
def test_interleave_moe_layer_step_parser():
"""Test InterleaveMoeLayerStepParser correctly computes MoE layer count."""
hf_config = Llama4Config(
text_config=Llama4TextConfig(
num_hidden_layers=32,
num_local_experts=64,
interleave_moe_layer_step=4, # Every 4th layer is MoE
),
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = FfnMetrics.get_parser()
result = parser_chain.parse(vllm_config)
assert result.num_moe_layers == 8
def test_moe_layer_freq_parser():
"""Test MoeLayerFreqParser correctly computes MoE layer count."""
hf_config = DeepseekV3Config(
num_hidden_layers=30,
n_routed_experts=64,
moe_layer_freq=3, # Every 3rd layer after first_k_dense_replace
first_k_dense_replace=6, # First 6 layers are dense
)
vllm_config = create_mock_vllm_config(hf_config)
parser_chain = FfnMetrics.get_parser()
result = parser_chain.parse(vllm_config)
# Layers >= 6 and divisible by 3: 6, 9, 12, 15, 18, 21, 24, 27
expected_moe_layers = len(
[layer for layer in range(30) if layer >= 6 and layer % 3 == 0]
)
assert expected_moe_layers == 8
assert result.num_moe_layers == expected_moe_layers
#### ComponentMetrics Tests ####
def test_attention_metrics_scaling():
"""Test that attention metrics scale proportionally with model dimensions."""
base_hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_key_value_heads=16,
num_hidden_layers=12,
head_dim=128,
)
base_vllm_config = create_mock_vllm_config(base_hf_config)
base_metrics = AttentionMetrics.from_vllm_config(base_vllm_config)
# Test scaling with number of layers
double_layers_hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_key_value_heads=16,
num_hidden_layers=24, # Double the layers
head_dim=128,
)
double_layers_vllm_config = create_mock_vllm_config(double_layers_hf_config)
double_layers_metrics = AttentionMetrics.from_vllm_config(double_layers_vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# FLOPS should double when layers double
base_flops = base_metrics.get_num_flops(ctx)
double_flops = double_layers_metrics.get_num_flops(ctx)
assert double_flops == 2 * base_flops
# Read/write bytes should also scale proportionally
base_read = base_metrics.get_read_bytes(ctx)
double_read = double_layers_metrics.get_read_bytes(ctx)
assert double_read == 2 * base_read
base_write = base_metrics.get_write_bytes(ctx)
double_write = double_layers_metrics.get_write_bytes(ctx)
assert double_write == 2 * base_write
def test_attention_metrics_grouped_query():
"""Test attention metrics handle grouped query attention correctly."""
mha_hf_config = Qwen3Config(
hidden_size=4096,
num_attention_heads=32,
num_key_value_heads=32, # MHA
num_hidden_layers=1,
)
mha_config = create_mock_vllm_config(mha_hf_config)
gqa_hf_config = Qwen3Config(
hidden_size=4096,
num_attention_heads=32,
num_key_value_heads=8, # GQA with 4:1 ratio
num_hidden_layers=1,
)
gqa_config = create_mock_vllm_config(gqa_hf_config)
mha_metrics = AttentionMetrics.from_vllm_config(mha_config)
gqa_metrics = AttentionMetrics.from_vllm_config(gqa_config)
ctx = ExecutionContext.from_single_request(
num_tokens=1, context_len=1024, is_prefill=False
)
# GQA should have less KV cache reads since fewer KV heads
mha_read = mha_metrics.get_read_bytes(ctx)
gqa_read = gqa_metrics.get_read_bytes(ctx)
assert gqa_read < mha_read
def test_ffn_metrics_scaling():
"""Test FFN metrics scale proportionally with model dimensions."""
base_hf_config = Qwen3Config(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
)
base_vllm_config = create_mock_vllm_config(base_hf_config)
base_metrics = FfnMetrics.from_vllm_config(base_vllm_config)
# Test scaling with intermediate size
larger_ffn_hf_config = Qwen3Config(
hidden_size=2048,
intermediate_size=16384, # Double intermediate size
num_hidden_layers=12,
)
larger_ffn_vllm_config = create_mock_vllm_config(larger_ffn_hf_config)
larger_ffn_metrics = FfnMetrics.from_vllm_config(larger_ffn_vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# FLOPS should double when intermediate size doubles
base_flops = base_metrics.get_num_flops(ctx)
larger_flops = larger_ffn_metrics.get_num_flops(ctx)
assert larger_flops == base_flops * 2
def test_moe_metrics_vs_dense():
"""Test MoE metrics versus dense metrics."""
dense_hf_config = Qwen3Config(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
)
dense_config = create_mock_vllm_config(dense_hf_config)
moe_hf_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
num_experts=64,
num_experts_per_tok=2, # 2 routed expert
moe_intermediate_size=8192,
n_shared_experts=0,
)
moe_config = create_mock_vllm_config(moe_hf_config)
dense_metrics = FfnMetrics.from_vllm_config(dense_config)
moe_metrics = FfnMetrics.from_vllm_config(moe_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# MoE should have different compute/memory characteristics
dense_flops = dense_metrics.get_num_flops(ctx)
moe_flops = moe_metrics.get_num_flops(ctx)
# 2 routed experts vs 1 dense.
assert moe_flops == dense_flops * 2
def test_unembed_metrics_scaling():
"""Test unembedding metrics scale with vocab size."""
small_vocab_hf_config = Qwen3Config(
hidden_size=2048,
vocab_size=32000,
)
small_vocab_config = create_mock_vllm_config(small_vocab_hf_config)
large_vocab_hf_config = Qwen3Config(
hidden_size=2048,
vocab_size=64000, # Double vocab size
)
large_vocab_config = create_mock_vllm_config(large_vocab_hf_config)
small_vocab_metrics = UnembedMetrics.from_vllm_config(small_vocab_config)
large_vocab_metrics = UnembedMetrics.from_vllm_config(large_vocab_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# FLOPS should double when vocab size doubles
small_flops = small_vocab_metrics.get_num_flops(ctx)
large_flops = large_vocab_metrics.get_num_flops(ctx)
assert large_flops == 2 * small_flops
def test_prefill_vs_decode_differences():
"""Test that prefill and decode have different memory access patterns."""
hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_key_value_heads=16,
num_hidden_layers=1,
)
config = create_mock_vllm_config(hf_config)
metrics = AttentionMetrics.from_vllm_config(config)
prefill_ctx = ExecutionContext.from_single_request(
num_tokens=512, context_len=512, is_prefill=True
)
decode_ctx = ExecutionContext.from_single_request(
num_tokens=1, context_len=512, is_prefill=False
)
prefill_read = metrics.get_read_bytes(prefill_ctx)
decode_read = metrics.get_read_bytes(decode_ctx)
assert prefill_read != decode_read
def test_model_metrics_aggregation():
"""Test ModelMetrics correctly aggregates across components."""
hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_hidden_layers=12,
vocab_size=32000,
intermediate_size=8192,
)
config = create_mock_vllm_config(hf_config)
model_metrics = ModelMetrics(config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Should have metrics for attention, ffn, and unembed
total_flops = model_metrics.get_num_flops(ctx)
breakdown = model_metrics.get_num_flops_breakdown(ctx)
# Breakdown should sum to total
assert total_flops == sum(breakdown.values())
def test_moe_expert_activation_proportional_scaling():
"""Test that routed expert metrics scale proportionally with num_experts_per_tok."""
base_moe_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
num_experts=64,
num_experts_per_tok=1, # 1 expert per token
moe_intermediate_size=8192,
n_shared_experts=2,
)
double_experts_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
num_experts=64,
num_experts_per_tok=2, # 2 experts per token (double)
moe_intermediate_size=8192,
n_shared_experts=2, # Same shared experts
)
triple_experts_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
num_experts=64,
num_experts_per_tok=3, # 3 experts per token (triple)
moe_intermediate_size=8192,
n_shared_experts=2, # Same shared experts
)
base_vllm_config = create_mock_vllm_config(base_moe_config)
double_vllm_config = create_mock_vllm_config(double_experts_config)
triple_vllm_config = create_mock_vllm_config(triple_experts_config)
base_metrics = FfnMetrics.from_vllm_config(base_vllm_config)
double_metrics = FfnMetrics.from_vllm_config(double_vllm_config)
triple_metrics = FfnMetrics.from_vllm_config(triple_vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Get total metrics - the key insight is that differences should be proportional
base_flops = base_metrics.get_num_flops(ctx)
double_flops = double_metrics.get_num_flops(ctx)
triple_flops = triple_metrics.get_num_flops(ctx)
# The difference between double and base should equal one additional expert
one_expert_diff = double_flops - base_flops
# The difference between triple and base should equal two additional experts
two_expert_diff = triple_flops - base_flops
# Proportional scaling: 2 * (1 expert diff) should equal (2 expert diff)
assert two_expert_diff == 2 * one_expert_diff
# Same logic applies to memory operations
base_read = base_metrics.get_read_bytes(ctx)
double_read = double_metrics.get_read_bytes(ctx)
triple_read = triple_metrics.get_read_bytes(ctx)
one_expert_read_diff = double_read - base_read
two_expert_read_diff = triple_read - base_read
assert two_expert_read_diff == 2 * one_expert_read_diff
# Same for write bytes
base_write = base_metrics.get_write_bytes(ctx)
double_write = double_metrics.get_write_bytes(ctx)
triple_write = triple_metrics.get_write_bytes(ctx)
one_expert_write_diff = double_write - base_write
two_expert_write_diff = triple_write - base_write
assert two_expert_write_diff == 2 * one_expert_write_diff
def test_quantization_config_parser_fp8():
"""Test quantization parsers with fp8."""
class MockQuantConfig:
def get_name(self):
return "fp8"
hf_config = Qwen3Config(
hidden_size=2048, num_attention_heads=16, num_hidden_layers=1
)
vllm_config = create_mock_vllm_config(hf_config, quant_config=MockQuantConfig())
attn_result = AttentionMetrics.get_parser().parse(vllm_config)
assert attn_result.weight_byte_size == 1 # fp8
ffn_result = FfnMetrics.get_parser().parse(vllm_config)
assert ffn_result.weight_byte_size == 1 # fp8
def test_quantization_config_parser_mxfp4():
"""Test quantization parsers with mxfp4."""
class MockQuantConfig:
def get_name(self):
return "mxfp4"
hf_config = Qwen3Config(
hidden_size=2048, intermediate_size=8192, num_hidden_layers=1
)
vllm_config = create_mock_vllm_config(hf_config, quant_config=MockQuantConfig())
ffn_result = FfnMetrics.get_parser().parse(vllm_config)
assert ffn_result.weight_byte_size == 0.5 # mxfp4
#### Per-GPU Tests ####
def test_attention_per_gpu_with_tensor_parallelism():
"""Test attention metrics with tensor parallelism - per_gpu vs global."""
hf_config = Qwen3Config(
hidden_size=4096,
num_attention_heads=32,
num_key_value_heads=8,
num_hidden_layers=24,
)
# Test with TP=4
vllm_config = create_mock_vllm_config(hf_config, tensor_parallel_size=4)
metrics = AttentionMetrics.from_vllm_config(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=128, context_len=1024, is_prefill=True
)
# Get global and per-gpu metrics
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
# With TP=4, global flops should be 4x per-gpu flops (heads divided by 4)
assert global_flops == 4 * per_gpu_flops
# Same for read/write bytes
global_read = metrics.get_read_bytes(ctx, per_gpu=False)
per_gpu_read = metrics.get_read_bytes(ctx, per_gpu=True)
# Reads should scale similarly (weight reads are divided by TP)
assert global_read > per_gpu_read
global_write = metrics.get_write_bytes(ctx, per_gpu=False)
per_gpu_write = metrics.get_write_bytes(ctx, per_gpu=True)
assert global_write > per_gpu_write
def test_attention_per_gpu_with_pipeline_parallelism():
"""Test attention metrics with pipeline parallelism - per_gpu vs global."""
hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_hidden_layers=32,
)
# Test with PP=4
vllm_config = create_mock_vllm_config(hf_config, pipeline_parallel_size=4)
metrics = AttentionMetrics.from_vllm_config(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=False
)
# Get global and per-gpu metrics
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
# With PP=4, global flops should be 4x per-gpu flops (layers divided by 4)
assert global_flops == 4 * per_gpu_flops
global_read = metrics.get_read_bytes(ctx, per_gpu=False)
per_gpu_read = metrics.get_read_bytes(ctx, per_gpu=True)
assert global_read == 4 * per_gpu_read
def test_ffn_per_gpu_with_tensor_parallelism():
"""Test FFN metrics with tensor parallelism - per_gpu vs global."""
hf_config = Qwen3Config(
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
)
# Test with DP=2, TP=4 (ffn_tp_size will be 8)
vllm_config = create_mock_vllm_config(
hf_config,
data_parallel_size=2,
tensor_parallel_size=4,
)
metrics = FfnMetrics.from_vllm_config(vllm_config)
# ffn_tp_size should be dp_size * tp_size = 8 (when EP not enabled)
assert metrics.ffn_tp_size == 8
ctx = ExecutionContext.from_single_request(
num_tokens=128, context_len=2048, is_prefill=True
)
# Get global and per-gpu metrics
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
# With ffn_tp_size=8, global should be 8x per-gpu
assert global_flops == 8 * per_gpu_flops
def test_ffn_per_gpu_with_pipeline_parallelism():
"""Test FFN metrics with pipeline parallelism - per_gpu vs global."""
hf_config = Qwen3Config(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=24,
)
# Test with PP=6
vllm_config = create_mock_vllm_config(hf_config, pipeline_parallel_size=6)
metrics = FfnMetrics.from_vllm_config(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Get global and per-gpu metrics
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
# With PP=6, global should be 6x per-gpu (layers divided by 6)
assert global_flops == 6 * per_gpu_flops
def test_moe_per_gpu_with_expert_parallelism():
"""
Test MoE metrics with expert parallelism - verifies num_activated_experts bug fix.
"""
hf_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=24,
num_experts=64,
num_experts_per_tok=8,
moe_intermediate_size=14336,
n_shared_experts=2,
)
# Test with DP=2, TP=4, EP enabled (ffn_ep_size will be 8)
vllm_config = create_mock_vllm_config(
hf_config,
data_parallel_size=2,
tensor_parallel_size=4,
enable_expert_parallel=True,
)
metrics = FfnMetrics.from_vllm_config(vllm_config)
# When EP enabled, ffn_ep_size = dp_size * tp_size = 8
assert metrics.ffn_ep_size == 8
assert metrics.ffn_tp_size == 1
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Get per-gpu metrics
per_gpu_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=True)
global_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=False)
# Verify that routed expert weight reads are reasonable
# With per_gpu=True, each GPU has 64/8 = 8 experts
# T=100, E_per_gpu=8/8=1, so T*E=100 expert activations
# num_activated_experts should be min(100, 8) = 8
# Check that weight reads scale appropriately
# Global has all 64 experts, per-gpu has 8 experts
# So weight reads should reflect this difference
if "routed_up_gate_weights" in per_gpu_read_breakdown:
per_gpu_weight_reads = per_gpu_read_breakdown["routed_up_gate_weights"]
global_weight_reads = global_read_breakdown["routed_up_gate_weights"]
# The ratio should reflect the expert count difference
# This verifies the bug fix works correctly
assert per_gpu_weight_reads < global_weight_reads
# Global should read more experts than per-gpu
# Exact ratio depends on num_activated_experts calculation
ratio = global_weight_reads / per_gpu_weight_reads
# Should be > 1 since global has more experts to read
assert ratio > 1
def test_moe_per_gpu_expert_activation_accounting():
"""
Test that MoE correctly accounts for expert activations with small batch sizes.
"""
hf_config = Qwen3MoeConfig(
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=12,
num_experts=64,
num_experts_per_tok=8,
moe_intermediate_size=14336,
n_shared_experts=0, # No shared experts for this test
)
# Test with EP=8
vllm_config = create_mock_vllm_config(
hf_config,
data_parallel_size=8,
enable_expert_parallel=True,
)
metrics = FfnMetrics.from_vllm_config(vllm_config)
# Small batch: T=10, E_per_gpu=8/8=1
# Each GPU: T*E = 10*1 = 10 activations
# Experts per GPU: 64/8 = 8
# So num_activated_experts should be min(10, 8) = 8
small_ctx = ExecutionContext.from_single_request(
num_tokens=10, context_len=512, is_prefill=True
)
small_read = metrics.get_read_bytes_breakdown(small_ctx, per_gpu=True)
# Large batch: T=1000, E_per_gpu=1
# Each GPU: T*E = 1000*1 = 1000 activations
# Experts per GPU: 8
# So num_activated_experts should be min(1000, 8) = 8 (all experts activated)
large_ctx = ExecutionContext.from_single_request(
num_tokens=1000, context_len=512, is_prefill=True
)
large_read = metrics.get_read_bytes_breakdown(large_ctx, per_gpu=True)
# Weight reads should be similar (both activate all 8 experts per GPU)
# But activation reads should differ (proportional to T*E)
if "routed_up_gate_weights" in small_read:
small_weight = small_read["routed_up_gate_weights"]
large_weight = large_read["routed_up_gate_weights"]
# Weight reads should be the same (both read all 8 experts)
assert small_weight == large_weight
# But input activation reads should scale with T*E
small_input = small_read["routed_up_gate_input"]
large_input = large_read["routed_up_gate_input"]
assert large_input == 100 * small_input # 1000/10 = 100x
def test_unembed_per_gpu_with_tensor_parallelism():
"""Test unembed metrics with tensor parallelism - per_gpu vs global."""
hf_config = Qwen3Config(
hidden_size=4096,
vocab_size=128000,
)
# Test with TP=8
vllm_config = create_mock_vllm_config(hf_config, tensor_parallel_size=8)
metrics = UnembedMetrics.from_vllm_config(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Get global and per-gpu metrics
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
# With TP=8, vocab is divided by 8, so global should be 8x per-gpu
assert global_flops == 8 * per_gpu_flops
# For read bytes, weight reads scale with TP but input reads don't (replicated)
global_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=False)
per_gpu_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=True)
# Input reads should be the same (replicated across TP ranks)
assert global_read_breakdown["input"] == per_gpu_read_breakdown["input"]
# Weight reads should scale 8x (divided by TP)
assert global_read_breakdown["weight"] == 8 * per_gpu_read_breakdown["weight"]
def test_model_metrics_per_gpu_aggregation():
"""Test ModelMetrics correctly aggregates per_gpu metrics across components."""
hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=16,
num_hidden_layers=12,
vocab_size=32000,
intermediate_size=8192,
)
# Test with mixed parallelism: TP=2, PP=2
vllm_config = create_mock_vllm_config(
hf_config,
tensor_parallel_size=2,
pipeline_parallel_size=2,
)
model_metrics = ModelMetrics(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=100, context_len=512, is_prefill=True
)
# Get breakdowns for both modes
per_gpu_breakdown = model_metrics.get_num_flops_breakdown(ctx, per_gpu=True)
global_breakdown = model_metrics.get_num_flops_breakdown(ctx, per_gpu=False)
# Verify breakdown sums match totals
per_gpu_total = model_metrics.get_num_flops(ctx, per_gpu=True)
global_total = model_metrics.get_num_flops(ctx, per_gpu=False)
assert per_gpu_total == sum(per_gpu_breakdown.values())
assert global_total == sum(global_breakdown.values())
# Global should be larger than per-gpu due to parallelism
assert global_total > per_gpu_total
# With TP=2 and PP=2, the ratio depends on which parallelism applies to
# which component but we can verify that global is reasonably larger
ratio = global_total / per_gpu_total
assert ratio > 1 # Should be between PP and TP*PP depending on component mix
def test_attention_per_gpu_heads_not_evenly_divisible():
"""Test attention with heads not evenly divisible by TP."""
hf_config = Qwen3Config(
hidden_size=2048,
num_attention_heads=17, # Not divisible by 4
num_key_value_heads=5, # Not divisible by 4
num_hidden_layers=8,
)
vllm_config = create_mock_vllm_config(hf_config, tensor_parallel_size=4)
metrics = AttentionMetrics.from_vllm_config(vllm_config)
ctx = ExecutionContext.from_single_request(
num_tokens=64, context_len=256, is_prefill=True
)
# Should not crash and should handle max(1, ...) correctly
per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True)
global_flops = metrics.get_num_flops(ctx, per_gpu=False)
# Both should be positive
assert per_gpu_flops > 0
assert global_flops > 0
assert global_flops > per_gpu_flops
...@@ -15,6 +15,7 @@ from tests.v1.attention.utils import ( ...@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
) )
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
AttentionConfig,
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
ModelConfig, ModelConfig,
...@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" ...@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def _create_proposer( def _create_proposer(
method: str, method: str,
num_speculative_tokens: int, num_speculative_tokens: int,
attention_backend: str | None = None,
speculative_token_tree: list[tuple[int, ...]] | None = None, speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer: ) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
...@@ -70,6 +72,7 @@ def _create_proposer( ...@@ -70,6 +72,7 @@ def _create_proposer(
max_model_len=model_config.max_model_len, max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder, is_encoder_decoder=model_config.is_encoder_decoder,
), ),
attention_config=AttentionConfig(backend=attention_backend),
) )
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
...@@ -331,8 +334,6 @@ def test_load_model( ...@@ -331,8 +334,6 @@ def test_load_model(
use_distinct_lm_head, use_distinct_lm_head,
monkeypatch, monkeypatch,
): ):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip( pytest.skip(
"TRITON_ATTN does not support " "TRITON_ATTN does not support "
...@@ -396,7 +397,9 @@ def test_load_model( ...@@ -396,7 +397,9 @@ def test_load_model(
assert not isinstance(target_model, SupportsMultiModal) assert not isinstance(target_model, SupportsMultiModal)
# Create proposer using the helper function # Create proposer using the helper function
proposer = _create_proposer(method, num_speculative_tokens=8) proposer = _create_proposer(
method, num_speculative_tokens=8, attention_backend=attn_backend
)
# Call the method under test # Call the method under test
proposer.load_model(target_model) proposer.load_model(target_model)
...@@ -422,8 +425,6 @@ def test_load_model( ...@@ -422,8 +425,6 @@ def test_load_model(
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip( pytest.skip(
"TRITON_ATTN does not support " "TRITON_ATTN does not support "
...@@ -451,7 +452,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -451,7 +452,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
seq_lens = [seq_len_1, seq_len_2] seq_lens = [seq_len_1, seq_len_2]
# Create proposer first so we can use its actual hidden_size # Create proposer first so we can use its actual hidden_size
proposer = _create_proposer("eagle", num_speculative_tokens) proposer = _create_proposer(
"eagle", num_speculative_tokens, attention_backend=attn_backend
)
# Get the hidden_size from the proposer to ensure consistency # Get the hidden_size from the proposer to ensure consistency
hidden_size = proposer.hidden_size hidden_size = proposer.hidden_size
...@@ -624,7 +627,9 @@ def test_propose_tree(spec_token_tree): ...@@ -624,7 +627,9 @@ def test_propose_tree(spec_token_tree):
# Create proposer first so we can use its actual hidden_size. # Create proposer first so we can use its actual hidden_size.
proposer = _create_proposer( proposer = _create_proposer(
"eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree "eagle",
num_speculative_tokens,
speculative_token_tree=spec_token_tree,
) )
# Get the hidden_size from the proposer to ensure consistency. # Get the hidden_size from the proposer to ensure consistency.
hidden_size = proposer.hidden_size hidden_size = proposer.hidden_size
......
...@@ -38,53 +38,48 @@ def test_ngram_max_len(num_speculative_tokens: int): ...@@ -38,53 +38,48 @@ def test_ngram_max_len(num_speculative_tokens: int):
def test_eagle_max_len( def test_eagle_max_len(
monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str
): ):
with monkeypatch.context() as m: if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) pytest.skip(
"TRITON_ATTN does not support "
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): "multi-token eagle spec decode on current platform"
pytest.skip( )
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
llm = LLM( llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct", model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization. enforce_eager=True, # For faster initialization.
speculative_config={ speculative_config={
"method": "eagle", "method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": num_speculative_tokens, "num_speculative_tokens": num_speculative_tokens,
"max_model_len": 80, "max_model_len": 80,
}, },
max_model_len=200, max_model_len=200,
attention_config={"backend": attn_backend},
)
sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
outputs = llm.generate(_PROMPTS, sampling_params)
for o in outputs:
assert o.outputs[0].finish_reason == "length", (
"This test is only meaningful if the output is truncated due to max length"
) )
sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
outputs = llm.generate(_PROMPTS, sampling_params)
for o in outputs:
assert o.outputs[0].finish_reason == "length", (
"This test is only meaningful if the output "
"is truncated due to max length"
)
sampling_params = SamplingParams( sampling_params = SamplingParams(
max_tokens=200, max_tokens=200,
structured_outputs=StructuredOutputsParams( structured_outputs=StructuredOutputsParams(regex="^" + "a b c d e " * 15 + "$"),
regex="^" + "a b c d e " * 15 + "$" )
), output = llm.generate(_PROMPTS, sampling_params)
for o in output:
assert o.prompt_token_ids is not None
assert (
len(o.prompt_token_ids)
< 80
< len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
<= 200
), (
"This test is only meaningful if the output "
"is longer than the eagle max length"
) )
output = llm.generate(_PROMPTS, sampling_params) assert o.outputs[0].text == "a b c d e " * 15
for o in output:
assert o.prompt_token_ids is not None
assert (
len(o.prompt_token_ids)
< 80
< len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
<= 200
), (
"This test is only meaningful if the output "
"is longer than the eagle max length"
)
assert o.outputs[0].text == "a b c d e " * 15
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
""" """
Test: Test:
* Tests for MultiHeadAttention layer * Tests for MMEncoderAttention layer
""" """
import pytest import pytest
...@@ -12,7 +12,7 @@ import torch_xla ...@@ -12,7 +12,7 @@ import torch_xla
import torch_xla.core import torch_xla.core
import torch_xla.core.xla_model import torch_xla.core.xla_model
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.selector import _cached_get_attn_backend from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -69,7 +69,7 @@ def test_mha_attn_forward( ...@@ -69,7 +69,7 @@ def test_mha_attn_forward(
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
scale = 1.0 / head_size**0.5 scale = 1.0 / head_size**0.5
attn = MultiHeadAttention( attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
) )
output = attn(q, k, v) output = attn(q, k, v)
......
...@@ -1110,3 +1110,87 @@ def test_hybrid_cache_integration(model_runner, dist_init): ...@@ -1110,3 +1110,87 @@ def test_hybrid_cache_integration(model_runner, dist_init):
runner._update_states(scheduler_output) runner._update_states(scheduler_output)
assert _is_req_scheduled(runner, req_id) assert _is_req_scheduled(runner, req_id)
assert _is_req_state_block_table_match(runner, req_id) assert _is_req_state_block_table_match(runner, req_id)
def test_is_uniform_decode() -> None:
# Normal
assert GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=2,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=15,
)
# Spec decoding
assert GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=5,
uniform_decode_query_len=5,
num_tokens=30,
num_reqs=6,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=5,
uniform_decode_query_len=4,
num_tokens=30,
num_reqs=6,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=5,
uniform_decode_query_len=5,
num_tokens=30,
num_reqs=7,
)
# Force uniform decode
assert GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
force_uniform_decode=True,
)
assert GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=2,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
force_uniform_decode=True,
)
assert GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=15,
force_uniform_decode=True,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
force_uniform_decode=False,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=2,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
force_uniform_decode=False,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=15,
force_uniform_decode=False,
)
...@@ -24,14 +24,13 @@ def is_aiter_found() -> bool: ...@@ -24,14 +24,13 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks. # we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found() IS_AITER_FOUND = is_aiter_found()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if IS_AITER_FOUND:
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8 def is_aiter_found_and_supported() -> bool:
if current_platform.is_rocm() and IS_AITER_FOUND:
from vllm.platforms.rocm import on_gfx9
return on_gfx9()
return False
def if_aiter_supported(func: Callable) -> Callable: def if_aiter_supported(func: Callable) -> Callable:
...@@ -43,17 +42,24 @@ def if_aiter_supported(func: Callable) -> Callable: ...@@ -43,17 +42,24 @@ def if_aiter_supported(func: Callable) -> Callable:
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
# checks the platform, device arch and aiter library existence. # checks the platform, device arch and aiter library existence.
if current_platform.is_rocm() and IS_AITER_FOUND: if is_aiter_found_and_supported():
from vllm.platforms.rocm import on_gfx9 return func(*args, **kwargs)
if on_gfx9():
return func(*args, **kwargs)
return None return None
return wrapper return wrapper
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if is_aiter_found_and_supported():
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def _rocm_aiter_fused_moe_impl( def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
...@@ -642,48 +648,130 @@ _OPS_REGISTERED = False ...@@ -642,48 +648,130 @@ _OPS_REGISTERED = False
class rocm_aiter_ops: class rocm_aiter_ops:
"""ROCm AITER operations wrapper for AMD GPU acceleration in vLLM.
This class centralizes the import and registration of AITER ops,
and provides a unified interface for checking if AITER is enabled.
Operations are only available on supported gfx9
architectures when aiter is installed.
The class uses environment variables to control which features are enabled,
allowing fine-grained control over which AITER optimizations are used.
Environment Variables:
VLLM_ROCM_USE_AITER: Main toggle for all AITER operations.
VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops.
VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations.
VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops.
VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops.
VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen.
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention.
VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply.
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM.
VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings.
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion.
VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM.
Note:
The environment variables are assigned when the module is imported,
so you can't change the environment variables after the module is imported.
This is done out of performance consideration. Accessing environment variables
is expensive as described in issue https://github.com/vllm-project/vllm/issues/17067
so we don't want to do it repeatedly, especially in the hot path (the forward pass).
You can call the refresh_env_variables() function to reload the env variables
after monkey patching the env variables in the unit test.
Check Functions:
All check functions (is_*_enabled) are decorated with @if_aiter_supported,
which verifies: (1) platform is ROCm, (2) device arch is gfx9, and
(3) aiter library is installed. The check function then also verifies
the corresponding environment variable is enabled.
i.e. ___
is_enabled() == current_platform.is_rocm() and | checked by
current_platform.is_on_gfx9() and | @if_aiter_supported
IS_AITER_FOUND and _______________|
cls._AITER_ENABLED -----> Check by the logic in `is_enabled()`
Example:
from vllm._aiter_ops import rocm_aiter_ops
# Check if aiter is enabled before using operations
if rocm_aiter_ops.is_enabled():
result = rocm_aiter_ops.rms_norm(x, weight, epsilon)
Operations:
- RMS normalization: rms_norm, rms_norm2d_with_add
- GEMM operations: gemm_a8w8, gemm_a8w8_blockscale
- Fused MoE: fused_moe, asm_moe_tkw1
- Routing: topk_softmax, biased_grouped_topk, grouped_topk
- MLA decode: mla_decode_fwd
- Quantization: per_tensor_quant, per_token_quant, group_fp8_quant
- Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale
"""
# Check if the env variable is set
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
# TODO: Consolidate under _LINEAR_ENABLED
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
# TODO: Consolidate under _LINEAR_ENABLED
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
_TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
# TODO: Consolidate under _LINEAR_ENABLED
_TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
@classmethod
def refresh_env_variables(cls):
"""
Since the environment variables are assigned when the module is imported,
This is a helper function to reload all the env variables from
the environment variables.
for example, after monkey patching the env variables in the unit test,
you can call this function to reload the env variables.
"""
cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_enabled(cls) -> bool: def is_enabled(cls) -> bool:
"""Verifies device specs and availability of aiter main env variable."""
return cls._AITER_ENABLED return cls._AITER_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_linear_enabled(cls) -> bool: def is_linear_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._LINEAR_ENABLED return cls._AITER_ENABLED and cls._LINEAR_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool: def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() return cls.is_linear_enabled()
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_rmsnorm_enabled(cls) -> bool: def is_rmsnorm_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._RMSNORM_ENABLED return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_fused_moe_enabled(cls) -> bool: def is_fused_moe_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._FMOE_ENABLED return cls._AITER_ENABLED and cls._FMOE_ENABLED
@classmethod @classmethod
...@@ -694,25 +782,16 @@ class rocm_aiter_ops: ...@@ -694,25 +782,16 @@ class rocm_aiter_ops:
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_mla_enabled(cls) -> bool: def is_mla_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MLA_ENABLED return cls._AITER_ENABLED and cls._MLA_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_mha_enabled(cls) -> bool: def is_mha_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MHA_ENABLED return cls._AITER_ENABLED and cls._MHA_ENABLED
@classmethod
@if_aiter_supported
def is_pa_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_triton_unified_attn_enabled(cls) -> bool: def is_triton_unified_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED
@classmethod @classmethod
......
...@@ -2933,6 +2933,42 @@ def cpu_gemm_wna16( ...@@ -2933,6 +2933,42 @@ def cpu_gemm_wna16(
return output return output
def cpu_prepack_moe_weight(
weight: torch.Tensor,
isa: str,
) -> torch.Tensor:
output = torch.empty_like(weight)
torch.ops._C.prepack_moe_weight(weight, output, isa)
return output
def cpu_fused_moe(
input: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
act: str,
isa: str,
) -> torch.Tensor:
output = torch.empty_like(input)
torch.ops._C.cpu_fused_moe(
output,
input,
w13,
w2,
w13_bias,
w2_bias,
topk_weights,
topk_ids,
act,
isa,
)
return output
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"): if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
@register_fake("_qutlass_C::matmul_mxf4_bf16_tn") @register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
......
...@@ -201,8 +201,8 @@ _MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {} ...@@ -201,8 +201,8 @@ _MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}
def register_backend( def register_backend(
backend: AttentionBackendEnum | MambaAttentionBackendEnum, backend: AttentionBackendEnum | MambaAttentionBackendEnum,
is_mamba: bool = False,
class_path: str | None = None, class_path: str | None = None,
is_mamba: bool = False,
) -> Callable[[type], type]: ) -> Callable[[type], type]:
"""Register or override a backend implementation. """Register or override a backend implementation.
......
...@@ -2,12 +2,10 @@ ...@@ -2,12 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer.""" """Attention layer."""
import functools
from typing import cast from typing import cast
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
...@@ -16,13 +14,10 @@ from vllm.attention.backends.abstract import ( ...@@ -16,13 +14,10 @@ from vllm.attention.backends.abstract import (
MLAAttentionImpl, MLAAttentionImpl,
) )
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.config.vllm import VllmConfig from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -36,7 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -36,7 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
direct_register_custom_op, direct_register_custom_op,
...@@ -412,132 +406,6 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -412,132 +406,6 @@ class Attention(nn.Module, AttentionLayerBase):
) )
class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
# This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
# Determine the attention backend
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)
self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self.fa_version = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
and current_platform.is_cuda()
):
self.fa_version = get_flash_attn_version()
assert self._flash_attn_varlen_func is not None
self._flash_attn_varlen_func = functools.partial(
self._flash_attn_varlen_func, fa_version=self.fa_version
)
logger.info_once(
f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
if self.is_flash_attn_backend:
assert self._flash_attn_varlen_func is not None
cu_seqlens_q = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
)
cu_seqlens_k = torch.arange(
0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
)
out = self._flash_attn_varlen_func(
query.flatten(0, 1),
key.flatten(0, 1),
value.flatten(0, 1),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == AttentionBackendEnum.PALLAS:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
else:
# ViT attention hasn't supported this backend yet
raise NotImplementedError(
f"ViT attention hasn't supported {self.attn_backend} backend yet."
)
return out.reshape(bsz, q_len, -1)
class MLAAttention(nn.Module, AttentionLayerBase): class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer. """Multi-Head Latent Attention layer.
......
...@@ -4,7 +4,7 @@ import functools ...@@ -4,7 +4,7 @@ import functools
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
...@@ -51,11 +51,19 @@ def create_chunked_local_attention_backend( ...@@ -51,11 +51,19 @@ def create_chunked_local_attention_backend(
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False, fast_build: bool = False,
) -> AttentionMetadata: ):
common_attn_metadata = make_local_attention_virtual_batches( cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size attention_chunk_size, common_attn_metadata, block_size
) )
return super().build(common_prefix_len, common_attn_metadata, fast_build) metadata = super().build(common_prefix_len, cm, fast_build)
metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
return metadata
def update_block_table(
self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor
):
blk_table = metadata.make_virtual_batches_block_table(blk_table)
return super().update_block_table(metadata, blk_table, slot_mapping)
attn_backend = subclass_attention_backend( attn_backend = subclass_attention_backend(
name_prefix=prefix, name_prefix=prefix,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
...@@ -10,6 +9,7 @@ from vllm.attention.ops.vit_attn_wrappers import ( ...@@ -10,6 +9,7 @@ from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper, vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper, vit_torch_sdpa_wrapper,
) )
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import MultiModalConfig from vllm.config import MultiModalConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
...@@ -18,27 +18,6 @@ from vllm.model_executor.models.vision import get_vit_attn_backend ...@@ -18,27 +18,6 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
logger = init_logger(__name__) logger = init_logger(__name__)
def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
# At this point,
# we already have the attn_backend,
# overriding logic is done in the platform-specific implementation.
# so we don't need to override backend here.
# Just return the attn_backend and flash_attn_varlen_func.
if attn_backend == AttentionBackendEnum.FLASH_ATTN:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
# if attn_backend is TORCH_SDPA,
# it will reach here and the flash_attn_varlen_func will be None.
return flash_attn_varlen_func
@CustomOp.register("mm_encoder_attn") @CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp): class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder.""" """Multi-headed attention without any cache, used for multimodal encoder."""
...@@ -97,8 +76,8 @@ class MMEncoderAttention(CustomOp): ...@@ -97,8 +76,8 @@ class MMEncoderAttention(CustomOp):
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( self._fa_version = (
self.attn_backend, get_flash_attn_version() if self.is_flash_attn_backend else None
) )
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
...@@ -107,7 +86,7 @@ class MMEncoderAttention(CustomOp): ...@@ -107,7 +86,7 @@ class MMEncoderAttention(CustomOp):
def enabled(cls) -> bool: def enabled(cls) -> bool:
return True return True
def reshape_qkv_to_4d( def maybe_reshape_qkv_to_4d(
self, self,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
...@@ -131,30 +110,6 @@ class MMEncoderAttention(CustomOp): ...@@ -131,30 +110,6 @@ class MMEncoderAttention(CustomOp):
return query, key, value return query, key, value
def reshape_qkv_to_3d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 3D tensors:
(batch_size * seq_len, num_heads, head_size)
"""
query = query.view(bsz * q_len, self.num_heads, self.head_size)
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=1)
value = torch.repeat_interleave(value, num_repeat, dim=1)
return query, key, value
def _forward_sdpa( def _forward_sdpa(
self, self,
query: torch.Tensor, query: torch.Tensor,
...@@ -162,13 +117,15 @@ class MMEncoderAttention(CustomOp): ...@@ -162,13 +117,15 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor, value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO(Isotr0py): Migrate MultiHeadAttention """Input shape:
assert cu_seqlens is not None (batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz, q_len = query.size()[:2] bsz, q_len = query.size()[:2]
kv_len = key.size(1) kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.reshape_qkv_to_4d( query, key, value = self.maybe_reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len query, key, value, bsz, q_len, kv_len
) )
...@@ -178,6 +135,8 @@ class MMEncoderAttention(CustomOp): ...@@ -178,6 +135,8 @@ class MMEncoderAttention(CustomOp):
v=value, v=value,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
) )
if is_reshaped:
output = output.view(bsz, q_len, -1)
return output return output
def _forward_fa( def _forward_fa(
...@@ -188,13 +147,21 @@ class MMEncoderAttention(CustomOp): ...@@ -188,13 +147,21 @@ class MMEncoderAttention(CustomOp):
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor: ) -> torch.Tensor:
assert self.flash_attn_varlen_func is not None, ( """Input shape:
"Flash attention function is not set." (batch_size x seq_len x hidden_size) or
) (batch_size x seq_len x num_heads x head_size)
# # TODO(Isotr0py): Migrate MultiHeadAttention """
assert cu_seqlens is not None and max_seqlen is not None assert (cu_seqlens is not None and max_seqlen is not None) or (
cu_seqlens is None and max_seqlen is None
), "cu_seqlens and max_seqlen should be both set or both None."
bsz = query.shape[0] bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.maybe_reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len
)
output = vit_flash_attn_wrapper( output = vit_flash_attn_wrapper(
q=query, q=query,
...@@ -204,7 +171,10 @@ class MMEncoderAttention(CustomOp): ...@@ -204,7 +171,10 @@ class MMEncoderAttention(CustomOp):
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
batch_size=bsz, batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
fa_version=self._fa_version,
) )
if is_reshaped:
output = output.view(bsz, q_len, -1)
return output return output
def forward_native( def forward_native(
......
...@@ -24,15 +24,28 @@ def flash_attn_maxseqlen_wrapper( ...@@ -24,15 +24,28 @@ def flash_attn_maxseqlen_wrapper(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int | None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
kwargs = {}
if is_rocm_aiter: if is_rocm_aiter:
from aiter import flash_attn_varlen_func from aiter import flash_attn_varlen_func
else: else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func from vllm.attention.utils.fa_utils import flash_attn_varlen_func
if not current_platform.is_rocm() and fa_version is not None:
kwargs["fa_version"] = fa_version
q_len = q.size(1)
if cu_seqlens is None:
cu_seqlens = torch.arange(
0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
)
max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func( output = flash_attn_varlen_func(
q, q,
...@@ -40,10 +53,11 @@ def flash_attn_maxseqlen_wrapper( ...@@ -40,10 +53,11 @@ def flash_attn_maxseqlen_wrapper(
v, v,
cu_seqlens_q=cu_seqlens, cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen.item(), max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen.item(), max_seqlen_k=max_seqlen,
dropout_p=0.0, dropout_p=0.0,
causal=False, causal=False,
**kwargs,
) )
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size) context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
return context_layer return context_layer
...@@ -57,6 +71,7 @@ def flash_attn_maxseqlen_wrapper_fake( ...@@ -57,6 +71,7 @@ def flash_attn_maxseqlen_wrapper_fake(
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int | None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(q) return torch.empty_like(q)
...@@ -72,23 +87,42 @@ def vit_flash_attn_wrapper( ...@@ -72,23 +87,42 @@ def vit_flash_attn_wrapper(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int | None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper( return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter q,
k,
v,
batch_size,
is_rocm_aiter,
fa_version,
cu_seqlens,
max_seqlen,
) )
def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""
Input shape:
(batch_size x seq_len x num_heads x head_size)
"""
q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
output = einops.rearrange(output, "b h s d -> b s h d ")
return output
# TODO: Once we have a torch 2.10, we can use tensor slices # TODO: Once we have a torch 2.10, we can use tensor slices
# so we won't need to wrap this in custom ops # so we won't need to wrap this in custom ops
def torch_sdpa_wrapper( def torch_sdpa_wrapper(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
# Never remove the contiguous logic for ROCm # Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend # Without it, hallucinations occur with the backend
...@@ -97,6 +131,9 @@ def torch_sdpa_wrapper( ...@@ -97,6 +131,9 @@ def torch_sdpa_wrapper(
k = k.contiguous() k = k.contiguous()
v = v.contiguous() v = v.contiguous()
if cu_seqlens is None:
return apply_sdpa(q, k, v)
outputs = [] outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
...@@ -104,11 +141,7 @@ def torch_sdpa_wrapper( ...@@ -104,11 +141,7 @@ def torch_sdpa_wrapper(
k_chunks = torch.split(k, lens, dim=1) k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1) v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = ( output_i = apply_sdpa(q_i, k_i, v_i)
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i) outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) context_layer = torch.cat(outputs, dim=1)
return context_layer return context_layer
...@@ -134,6 +167,6 @@ def vit_torch_sdpa_wrapper( ...@@ -134,6 +167,6 @@ def vit_torch_sdpa_wrapper(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens) return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
...@@ -79,10 +79,6 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -79,10 +79,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
if args.profile and not engine_args.profiler_config.profiler == "torch":
raise ValueError(
"The torch profiler is not enabled. Please provide profiler_config."
)
# Lazy import to avoid importing LLM when the bench command is not selected. # Lazy import to avoid importing LLM when the bench command is not selected.
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
...@@ -125,8 +121,8 @@ def main(args: argparse.Namespace): ...@@ -125,8 +121,8 @@ def main(args: argparse.Namespace):
), ),
) )
def run_to_completion(profile_dir: str | None = None): def run_to_completion(do_profile: bool = False):
if profile_dir: if do_profile:
llm.start_profile() llm.start_profile()
llm_generate() llm_generate()
llm.stop_profile() llm.stop_profile()
...@@ -139,18 +135,24 @@ def main(args: argparse.Namespace): ...@@ -139,18 +135,24 @@ def main(args: argparse.Namespace):
print("Warming up...") print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion(profile_dir=None) run_to_completion(do_profile=False)
if args.profile: if args.profile:
profile_dir = engine_args.profiler_config.torch_profiler_dir profiler_config = engine_args.profiler_config
print(f"Profiling (results will be saved to '{profile_dir}')...") if profiler_config.profiler == "torch":
run_to_completion(profile_dir=profile_dir) print(
"Profiling with torch profiler (results will be saved to"
f" {profiler_config.torch_profiler_dir})..."
)
elif profiler_config.profiler == "cuda":
print("Profiling with cuda profiler ...")
run_to_completion(do_profile=True)
return return
# Benchmark. # Benchmark.
latencies = [] latencies = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): for _ in tqdm(range(args.num_iters), desc="Bench iterations"):
latencies.append(run_to_completion(profile_dir=None)) latencies.append(run_to_completion(do_profile=False))
latencies = np.array(latencies) latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90, 99] percentages = [10, 25, 50, 75, 90, 99]
percentiles = np.percentile(latencies, percentages) percentiles = np.percentile(latencies, percentages)
......
...@@ -10,8 +10,10 @@ On the client side, run: ...@@ -10,8 +10,10 @@ On the client side, run:
vllm bench serve \ vllm bench serve \
--backend <backend or endpoint type. Default 'openai'> \ --backend <backend or endpoint type. Default 'openai'> \
--label <benchmark result label. Default using backend> \ --label <benchmark result label. Default using backend> \
--model <your_model> \ --model <your_model. Optional, defaults to first model from server> \
--dataset-name <dataset_name. Default 'random'> \ --dataset-name <dataset_name. Default 'random'> \
--input-len <general input length. Optional, maps to dataset-specific args> \
--output-len <general output length. Optional, maps to dataset-specific args> \
--request-rate <request_rate. Default inf> \ --request-rate <request_rate. Default inf> \
--num-prompts <num_prompts. Default 1000> --num-prompts <num_prompts. Default 1000>
""" """
...@@ -57,6 +59,33 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a ...@@ -57,6 +59,33 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
) )
async def get_first_model_from_server(
base_url: str, headers: dict | None = None
) -> str:
"""Fetch the first model from the server's /v1/models endpoint."""
models_url = f"{base_url}/v1/models"
async with aiohttp.ClientSession() as session:
try:
async with session.get(models_url, headers=headers) as response:
response.raise_for_status()
data = await response.json()
if "data" in data and len(data["data"]) > 0:
return data["data"][0]["id"]
else:
raise ValueError(
f"No models found on the server at {base_url}. "
"Make sure the server is running and has models loaded."
)
except (aiohttp.ClientError, json.JSONDecodeError) as e:
raise RuntimeError(
f"Failed to fetch models from server at {models_url}. "
"Check that:\n"
"1. The server is running\n"
"2. The server URL is correct\n"
f"Error: {e}"
) from e
class TaskType(Enum): class TaskType(Enum):
GENERATION = "generation" GENERATION = "generation"
POOLING = "pooling" POOLING = "pooling"
...@@ -1025,8 +1054,26 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -1025,8 +1054,26 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
required=True, required=False,
help="Name of the model.", default=None,
help="Name of the model. If not specified, will fetch the first model "
"from the server's /v1/models endpoint.",
)
parser.add_argument(
"--input-len",
type=int,
default=None,
help="General input length for datasets. Maps to dataset-specific "
"input length arguments (e.g., --random-input-len, --sonnet-input-len). "
"If not specified, uses dataset defaults.",
)
parser.add_argument(
"--output-len",
type=int,
default=None,
help="General output length for datasets. Maps to dataset-specific "
"output length arguments (e.g., --random-output-len, --sonnet-output-len). "
"If not specified, uses dataset defaults.",
) )
parser.add_argument( parser.add_argument(
"--tokenizer", "--tokenizer",
...@@ -1332,10 +1379,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: ...@@ -1332,10 +1379,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
raise ValueError("For exponential ramp-up, the start RPS cannot be 0.") raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")
label = args.label label = args.label
model_id = args.model
model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode
if args.base_url is not None: if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}" api_url = f"{args.base_url}{args.endpoint}"
...@@ -1356,6 +1399,18 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: ...@@ -1356,6 +1399,18 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
else: else:
raise ValueError("Invalid header format. Please use KEY=VALUE format.") raise ValueError("Invalid header format. Please use KEY=VALUE format.")
# Fetch model from server if not specified
if args.model is None:
print("Model not specified, fetching first model from server...")
model_id = await get_first_model_from_server(base_url, headers)
print(f"Using model: {model_id}")
else:
model_id = args.model
model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else model_id
tokenizer_mode = args.tokenizer_mode
tokenizer = get_tokenizer( tokenizer = get_tokenizer(
tokenizer_id, tokenizer_id,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
...@@ -1368,6 +1423,20 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: ...@@ -1368,6 +1423,20 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
"'--dataset-path' if required." "'--dataset-path' if required."
) )
# Map general --input-len and --output-len to all dataset-specific arguments
if args.input_len is not None:
args.random_input_len = args.input_len
args.sonnet_input_len = args.input_len
if args.output_len is not None:
args.random_output_len = args.output_len
args.sonnet_output_len = args.output_len
args.sharegpt_output_len = args.output_len
args.custom_output_len = args.output_len
args.hf_output_len = args.output_len
args.spec_bench_output_len = args.output_len
args.prefix_repetition_output_len = args.output_len
# when using random datasets, default to ignoring EOS # when using random datasets, default to ignoring EOS
# so generation runs to the requested length # so generation runs to the requested length
if ( if (
......
...@@ -346,7 +346,10 @@ def get_requests(args, tokenizer): ...@@ -346,7 +346,10 @@ def get_requests(args, tokenizer):
"output_len": args.output_len, "output_len": args.output_len,
} }
if args.dataset_path is None or args.dataset_name == "random": if args.dataset_name == "random" or (
args.dataset_path is None
and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"}
):
sample_kwargs["range_ratio"] = args.random_range_ratio sample_kwargs["range_ratio"] = args.random_range_ratio
sample_kwargs["prefix_len"] = args.prefix_len sample_kwargs["prefix_len"] = args.prefix_len
dataset_cls = RandomDataset dataset_cls = RandomDataset
......
...@@ -520,6 +520,7 @@ class VllmBackend: ...@@ -520,6 +520,7 @@ class VllmBackend:
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
is_encoder: bool = False,
): ):
# if the model is initialized with a non-empty prefix, # if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix, # then usually it's enough to use that prefix,
...@@ -530,7 +531,7 @@ class VllmBackend: ...@@ -530,7 +531,7 @@ class VllmBackend:
self.prefix = prefix or model_tag self.prefix = prefix or model_tag
# Mark compilation for encoder. # Mark compilation for encoder.
self.is_encoder = model_is_encoder self.is_encoder = is_encoder or model_is_encoder
# Passes to run on the graph post-grad. # Passes to run on the graph post-grad.
self.pass_manager = resolve_obj_by_qualname( self.pass_manager = resolve_obj_by_qualname(
...@@ -797,7 +798,7 @@ class VllmBackend: ...@@ -797,7 +798,7 @@ class VllmBackend:
or not self.compilation_config.cudagraph_copy_inputs or not self.compilation_config.cudagraph_copy_inputs
): ):
return VllmSerializableFunction( return VllmSerializableFunction(
graph, example_inputs, self.prefix, self.split_gm graph, example_inputs, self.prefix, self.split_gm, self.is_encoder
) )
# index of tensors that have symbolic shapes (batch size) # index of tensors that have symbolic shapes (batch size)
...@@ -835,5 +836,5 @@ class VllmBackend: ...@@ -835,5 +836,5 @@ class VllmBackend:
return self.split_gm(*list_args) return self.split_gm(*list_args)
return VllmSerializableFunction( return VllmSerializableFunction(
graph, example_inputs, self.prefix, copy_and_call graph, example_inputs, self.prefix, copy_and_call, self.is_encoder
) )
...@@ -37,12 +37,15 @@ class VllmSerializableFunction(SerializableCallable): ...@@ -37,12 +37,15 @@ class VllmSerializableFunction(SerializableCallable):
serializing the Dynamo fx graph plus example inputs. serializing the Dynamo fx graph plus example inputs.
""" """
def __init__(self, graph_module, example_inputs, prefix, optimized_call): def __init__(
self, graph_module, example_inputs, prefix, optimized_call, is_encoder=False
):
assert isinstance(graph_module, torch.fx.GraphModule) assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module self.graph_module = graph_module
self.example_inputs = example_inputs self.example_inputs = example_inputs
self.prefix = prefix self.prefix = prefix
self.optimized_call = optimized_call self.optimized_call = optimized_call
self.is_encoder = is_encoder
self.shape_env = None self.shape_env = None
sym_input = next( sym_input = next(
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
...@@ -104,8 +107,12 @@ class VllmSerializableFunction(SerializableCallable): ...@@ -104,8 +107,12 @@ class VllmSerializableFunction(SerializableCallable):
state = pickle.loads(data) state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv()) fake_mode = FakeTensorMode(shape_env=ShapeEnv())
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"]) is_encoder = state.get("is_encoder", False)
vllm_backend = VllmBackend(
get_current_vllm_config(), state["prefix"], is_encoder
)
def optimized_call(*example_inputs): def optimized_call(*example_inputs):
""" """
......
...@@ -435,7 +435,10 @@ def _support_torch_compile( ...@@ -435,7 +435,10 @@ def _support_torch_compile(
return self.aot_compiled_fn(self, *args, **kwargs) return self.aot_compiled_fn(self, *args, **kwargs)
if self.compiled: if self.compiled:
assert not envs.VLLM_USE_AOT_COMPILE assert (
not envs.VLLM_USE_AOT_COMPILE
or self.vllm_config.compilation_config.backend == "eager"
)
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
# This is the path for the first compilation. # This is the path for the first compilation.
...@@ -508,7 +511,11 @@ def _support_torch_compile( ...@@ -508,7 +511,11 @@ def _support_torch_compile(
_torch27_patch_tensor_subclasses(), _torch27_patch_tensor_subclasses(),
torch._inductor.config.patch(**inductor_config_patches), torch._inductor.config.patch(**inductor_config_patches),
): ):
if envs.VLLM_USE_AOT_COMPILE: use_aot_compile = envs.VLLM_USE_AOT_COMPILE
if self.vllm_config.compilation_config.backend == "eager":
logger.warning("Detected eager backend, disabling AOT compile.")
use_aot_compile = False
if use_aot_compile:
self.aot_compiled_fn = self.aot_compile(*args, **kwargs) self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs) output = self.aot_compiled_fn(self, *args, **kwargs)
assert aot_compilation_path is not None assert aot_compilation_path is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment