Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a810671a
"tests/vscode:/vscode.git/clone" did not exist on "444f0e3f339caba85f84c6628e1df50605b241a0"
Commit
a810671a
authored
Jan 08, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori
parents
86b5aefe
6a09612b
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1394 additions
and
327 deletions
+1394
-327
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+4
-0
tests/v1/kv_offload/test_cpu_offloading.py
tests/v1/kv_offload/test_cpu_offloading.py
+7
-8
tests/v1/metrics/test_perf_metrics.py
tests/v1/metrics/test_perf_metrics.py
+897
-0
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+12
-7
tests/v1/spec_decode/test_max_len.py
tests/v1/spec_decode/test_max_len.py
+40
-45
tests/v1/tpu/test_mha_attn.py
tests/v1/tpu/test_mha_attn.py
+3
-3
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+84
-0
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+106
-27
vllm/_custom_ops.py
vllm/_custom_ops.py
+36
-0
vllm/attention/backends/registry.py
vllm/attention/backends/registry.py
+1
-1
vllm/attention/layer.py
vllm/attention/layer.py
+0
-132
vllm/attention/layers/chunked_local_attention.py
vllm/attention/layers/chunked_local_attention.py
+12
-4
vllm/attention/layers/mm_encoder_attention.py
vllm/attention/layers/mm_encoder_attention.py
+29
-59
vllm/attention/ops/vit_attn_wrappers.py
vllm/attention/ops/vit_attn_wrappers.py
+47
-14
vllm/benchmarks/latency.py
vllm/benchmarks/latency.py
+14
-12
vllm/benchmarks/serve.py
vllm/benchmarks/serve.py
+76
-7
vllm/benchmarks/throughput.py
vllm/benchmarks/throughput.py
+4
-1
vllm/compilation/backends.py
vllm/compilation/backends.py
+4
-3
vllm/compilation/caching.py
vllm/compilation/caching.py
+9
-2
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+9
-2
No files found.
tests/v1/kv_connector/unit/utils.py
View file @
a810671a
...
...
@@ -11,6 +11,7 @@ import torch
from
vllm
import
SamplingParams
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
DeviceConfig
,
KVTransferConfig
,
...
...
@@ -94,6 +95,7 @@ def create_vllm_config(
dtype
:
str
=
"float16"
,
cache_dtype
:
str
=
"auto"
,
hf_overrides
:
dict
[
str
,
Any
]
|
None
=
None
,
attention_backend
:
str
|
None
=
None
,
)
->
VllmConfig
:
"""Initialize VllmConfig For Testing."""
model_config
=
ModelConfig
(
...
...
@@ -131,12 +133,14 @@ def create_vllm_config(
enable_permute_local_kv
=
enable_permute_local_kv
,
kv_connector_extra_config
=
kv_connector_extra_config
or
{},
)
attention_config
=
AttentionConfig
(
backend
=
attention_backend
)
return
VllmConfig
(
scheduler_config
=
scheduler_config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
kv_transfer_config
=
kv_transfer_config
,
device_config
=
DeviceConfig
(
"cpu"
),
attention_config
=
attention_config
,
)
...
...
tests/v1/kv_offload/test_cpu_offloading.py
View file @
a810671a
...
...
@@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt
from
vllm.config
import
KVEventsConfig
,
KVTransferConfig
from
vllm.distributed.kv_events
import
BlockStored
,
KVEventBatch
from
vllm.platforms
import
current_platform
from
vllm.utils.system_utils
import
set_env_var
CPU_BLOCK_SIZES
=
[
48
]
ATTN_BACKENDS
=
[
"FLASH_ATTN"
]
...
...
@@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
topic
=
"test"
,
)
with
set_env_var
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
):
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
gpu_memory_utilization
=
0.5
,
kv_events
_config
=
kv_
events
_config
,
kv_transfer_config
=
kv_transfer_config
,
)
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
gpu_memory_utilization
=
0.5
,
kv_events_config
=
kv_events_config
,
kv_transfer
_config
=
kv_
transfer
_config
,
attention_config
=
{
"backend"
:
attn_backend
}
,
)
events_endpoint
=
events_endpoint
.
replace
(
"*"
,
"127.0.0.1"
)
subscriber
=
MockSubscriber
(
events_endpoint
,
topic
=
kv_events_config
.
topic
)
...
...
tests/v1/metrics/test_perf_metrics.py
0 → 100644
View file @
a810671a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the analytic estimators in metrics/flops.py.
"""
import
types
from
types
import
SimpleNamespace
from
transformers.models.deepseek_v3.configuration_deepseek_v3
import
DeepseekV3Config
from
transformers.models.llama4.configuration_llama4
import
(
Llama4Config
,
Llama4TextConfig
,
)
from
transformers.models.qwen3.configuration_qwen3
import
Qwen3Config
from
transformers.models.qwen3_moe.configuration_qwen3_moe
import
Qwen3MoeConfig
from
vllm.config.model
import
ModelConfig
,
get_hf_text_config
from
vllm.v1.metrics.perf
import
(
AttentionMetrics
,
BaseConfigParser
,
ExecutionContext
,
FfnMetrics
,
ModelMetrics
,
ParsedArgs
,
UnembedMetrics
,
)
class
MockModelConfig
:
"""Mock ModelConfig that implements the getter methods used by parsers."""
def
__init__
(
self
,
hf_config
,
dtype
):
self
.
hf_config
=
hf_config
self
.
hf_text_config
=
get_hf_text_config
(
hf_config
)
self
.
dtype
=
dtype
self
.
is_attention_free
=
False
def
__getattr__
(
self
,
name
):
# 1. Check if ModelConfig actually has this attribute
if
not
hasattr
(
ModelConfig
,
name
):
raise
AttributeError
(
f
"'
{
type
(
self
).
__name__
}
' object has no attribute '
{
name
}
' "
f
"and neither does 'ModelConfig'."
)
# 2. Fetch the attribute from the ModelConfig CLASS
attr
=
getattr
(
ModelConfig
,
name
)
# 3. Case A: It is a @property
if
isinstance
(
attr
,
property
):
# Manually invoke the property's getter, passing 'self' (this mock instance)
return
attr
.
__get__
(
self
,
self
.
__class__
)
# 4. Case B: It is a standard method (function)
if
isinstance
(
attr
,
types
.
FunctionType
):
# Bind the function to 'self' so it acts like a method of
# this instance. This creates a bound method where 'self' is
# automatically passed as the first arg.
return
types
.
MethodType
(
attr
,
self
)
# 5. Case C: It is a class attribute / static variable
return
attr
def
create_mock_vllm_config
(
hf_config
,
model_dtype
=
"bfloat16"
,
cache_dtype
=
"auto"
,
quant_config
=
None
,
data_parallel_size
=
1
,
tensor_parallel_size
=
1
,
pipeline_parallel_size
=
1
,
enable_expert_parallel
=
False
,
)
->
SimpleNamespace
:
vllm_config
=
SimpleNamespace
()
vllm_config
.
model_config
=
MockModelConfig
(
hf_config
,
model_dtype
)
vllm_config
.
cache_config
=
SimpleNamespace
()
vllm_config
.
cache_config
.
cache_dtype
=
cache_dtype
vllm_config
.
quant_config
=
quant_config
vllm_config
.
parallel_config
=
SimpleNamespace
()
vllm_config
.
parallel_config
.
data_parallel_size
=
data_parallel_size
vllm_config
.
parallel_config
.
tensor_parallel_size
=
tensor_parallel_size
vllm_config
.
parallel_config
.
pipeline_parallel_size
=
pipeline_parallel_size
vllm_config
.
parallel_config
.
enable_expert_parallel
=
enable_expert_parallel
return
vllm_config
#### Parser Tests ####
def
test_base_config_parser
():
"""Test BaseConfigParser extracts base model attributes correctly."""
hf_config
=
Qwen3Config
(
vocab_size
=
50000
,
hidden_size
=
2048
,
num_attention_heads
=
16
,
num_hidden_layers
=
24
,
)
vllm_config
=
create_mock_vllm_config
(
hf_config
,
model_dtype
=
"float16"
)
parser
=
BaseConfigParser
()
args
=
ParsedArgs
()
result
=
parser
.
parse
(
args
,
vllm_config
)
assert
result
.
vocab_size
==
50000
assert
result
.
hidden_size
==
2048
assert
result
.
num_attention_heads
==
16
assert
result
.
num_hidden_layers
==
24
assert
result
.
weight_byte_size
==
2
# float16 is 2 bytes
assert
result
.
activation_byte_size
==
2
# default activation size
def
test_base_attention_config_parser_with_gqa
():
"""Test BaseAttentionConfigParser with grouped query attention."""
hf_config
=
Qwen3Config
(
hidden_size
=
4096
,
num_attention_heads
=
32
,
num_key_value_heads
=
8
,
# GQA with 4:1 ratio
head_dim
=
128
,
)
vllm_config
=
create_mock_vllm_config
(
hf_config
)
parser_chain
=
AttentionMetrics
.
get_parser
()
result
=
parser_chain
.
parse
(
vllm_config
)
assert
result
.
num_key_value_heads
==
8
assert
result
.
head_dim
==
128
def
test_base_attention_config_parser_without_gqa
():
"""
Test BaseAttentionConfigParser defaults to MHA when num_key_value_heads not
specified.
"""
hf_config
=
Qwen3Config
(
hidden_size
=
4096
,
num_attention_heads
=
32
,
# No num_key_value_heads specified
)
vllm_config
=
create_mock_vllm_config
(
hf_config
)
parser_chain
=
AttentionMetrics
.
get_parser
()
result
=
parser_chain
.
parse
(
vllm_config
)
# Should default to MHA (num_key_value_heads = num_attention_heads)
assert
result
.
num_key_value_heads
==
32
def
test_base_ffn_config_parser_dense
():
"""Test BaseFfnConfigParser for dense FFN."""
hf_config
=
Qwen3Config
(
hidden_size
=
4096
,
intermediate_size
=
11008
,
num_hidden_layers
=
32
,
)
vllm_config
=
create_mock_vllm_config
(
hf_config
)
parser_chain
=
FfnMetrics
.
get_parser
()
result
=
parser_chain
.
parse
(
vllm_config
)
assert
result
.
intermediate_size
==
11008
assert
result
.
num_experts
==
0
assert
result
.
num_experts_per_tok
==
0
assert
result
.
num_moe_layers
==
0
# No MoE
def
test_base_ffn_config_parser_moe
():
"""Test BaseFfnConfigParser for MoE FFN."""
hf_config
=
Qwen3MoeConfig
(
hidden_size
=
4096
,
intermediate_size
=
11008
,
num_hidden_layers
=
32
,
num_experts
=
64
,
num_experts_per_tok
=
8
,
moe_intermediate_size
=
14336
,
n_shared_experts
=
2
,
)
vllm_config
=
create_mock_vllm_config
(
hf_config
)
parser_chain
=
FfnMetrics
.
get_parser
()
result
=
parser_chain
.
parse
(
vllm_config
)
assert
result
.
num_experts
==
64
assert
result
.
num_experts_per_tok
==
8
assert
result
.
moe_intermediate_size
==
14336
assert
result
.
num_shared_experts
==
2
assert
result
.
num_moe_layers
==
32
# All layers are MoE by default
def
test_interleave_moe_layer_step_parser
():
"""Test InterleaveMoeLayerStepParser correctly computes MoE layer count."""
hf_config
=
Llama4Config
(
text_config
=
Llama4TextConfig
(
num_hidden_layers
=
32
,
num_local_experts
=
64
,
interleave_moe_layer_step
=
4
,
# Every 4th layer is MoE
),
)
vllm_config
=
create_mock_vllm_config
(
hf_config
)
parser_chain
=
FfnMetrics
.
get_parser
()
result
=
parser_chain
.
parse
(
vllm_config
)
assert
result
.
num_moe_layers
==
8
def
test_moe_layer_freq_parser
():
"""Test MoeLayerFreqParser correctly computes MoE layer count."""
hf_config
=
DeepseekV3Config
(
num_hidden_layers
=
30
,
n_routed_experts
=
64
,
moe_layer_freq
=
3
,
# Every 3rd layer after first_k_dense_replace
first_k_dense_replace
=
6
,
# First 6 layers are dense
)
vllm_config
=
create_mock_vllm_config
(
hf_config
)
parser_chain
=
FfnMetrics
.
get_parser
()
result
=
parser_chain
.
parse
(
vllm_config
)
# Layers >= 6 and divisible by 3: 6, 9, 12, 15, 18, 21, 24, 27
expected_moe_layers
=
len
(
[
layer
for
layer
in
range
(
30
)
if
layer
>=
6
and
layer
%
3
==
0
]
)
assert
expected_moe_layers
==
8
assert
result
.
num_moe_layers
==
expected_moe_layers
#### ComponentMetrics Tests ####
def
test_attention_metrics_scaling
():
"""Test that attention metrics scale proportionally with model dimensions."""
base_hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
num_attention_heads
=
16
,
num_key_value_heads
=
16
,
num_hidden_layers
=
12
,
head_dim
=
128
,
)
base_vllm_config
=
create_mock_vllm_config
(
base_hf_config
)
base_metrics
=
AttentionMetrics
.
from_vllm_config
(
base_vllm_config
)
# Test scaling with number of layers
double_layers_hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
num_attention_heads
=
16
,
num_key_value_heads
=
16
,
num_hidden_layers
=
24
,
# Double the layers
head_dim
=
128
,
)
double_layers_vllm_config
=
create_mock_vllm_config
(
double_layers_hf_config
)
double_layers_metrics
=
AttentionMetrics
.
from_vllm_config
(
double_layers_vllm_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
True
)
# FLOPS should double when layers double
base_flops
=
base_metrics
.
get_num_flops
(
ctx
)
double_flops
=
double_layers_metrics
.
get_num_flops
(
ctx
)
assert
double_flops
==
2
*
base_flops
# Read/write bytes should also scale proportionally
base_read
=
base_metrics
.
get_read_bytes
(
ctx
)
double_read
=
double_layers_metrics
.
get_read_bytes
(
ctx
)
assert
double_read
==
2
*
base_read
base_write
=
base_metrics
.
get_write_bytes
(
ctx
)
double_write
=
double_layers_metrics
.
get_write_bytes
(
ctx
)
assert
double_write
==
2
*
base_write
def
test_attention_metrics_grouped_query
():
"""Test attention metrics handle grouped query attention correctly."""
mha_hf_config
=
Qwen3Config
(
hidden_size
=
4096
,
num_attention_heads
=
32
,
num_key_value_heads
=
32
,
# MHA
num_hidden_layers
=
1
,
)
mha_config
=
create_mock_vllm_config
(
mha_hf_config
)
gqa_hf_config
=
Qwen3Config
(
hidden_size
=
4096
,
num_attention_heads
=
32
,
num_key_value_heads
=
8
,
# GQA with 4:1 ratio
num_hidden_layers
=
1
,
)
gqa_config
=
create_mock_vllm_config
(
gqa_hf_config
)
mha_metrics
=
AttentionMetrics
.
from_vllm_config
(
mha_config
)
gqa_metrics
=
AttentionMetrics
.
from_vllm_config
(
gqa_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
1
,
context_len
=
1024
,
is_prefill
=
False
)
# GQA should have less KV cache reads since fewer KV heads
mha_read
=
mha_metrics
.
get_read_bytes
(
ctx
)
gqa_read
=
gqa_metrics
.
get_read_bytes
(
ctx
)
assert
gqa_read
<
mha_read
def
test_ffn_metrics_scaling
():
"""Test FFN metrics scale proportionally with model dimensions."""
base_hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
intermediate_size
=
8192
,
num_hidden_layers
=
12
,
)
base_vllm_config
=
create_mock_vllm_config
(
base_hf_config
)
base_metrics
=
FfnMetrics
.
from_vllm_config
(
base_vllm_config
)
# Test scaling with intermediate size
larger_ffn_hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
intermediate_size
=
16384
,
# Double intermediate size
num_hidden_layers
=
12
,
)
larger_ffn_vllm_config
=
create_mock_vllm_config
(
larger_ffn_hf_config
)
larger_ffn_metrics
=
FfnMetrics
.
from_vllm_config
(
larger_ffn_vllm_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
True
)
# FLOPS should double when intermediate size doubles
base_flops
=
base_metrics
.
get_num_flops
(
ctx
)
larger_flops
=
larger_ffn_metrics
.
get_num_flops
(
ctx
)
assert
larger_flops
==
base_flops
*
2
def
test_moe_metrics_vs_dense
():
"""Test MoE metrics versus dense metrics."""
dense_hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
intermediate_size
=
8192
,
num_hidden_layers
=
12
,
)
dense_config
=
create_mock_vllm_config
(
dense_hf_config
)
moe_hf_config
=
Qwen3MoeConfig
(
hidden_size
=
2048
,
intermediate_size
=
8192
,
num_hidden_layers
=
12
,
num_experts
=
64
,
num_experts_per_tok
=
2
,
# 2 routed expert
moe_intermediate_size
=
8192
,
n_shared_experts
=
0
,
)
moe_config
=
create_mock_vllm_config
(
moe_hf_config
)
dense_metrics
=
FfnMetrics
.
from_vllm_config
(
dense_config
)
moe_metrics
=
FfnMetrics
.
from_vllm_config
(
moe_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
True
)
# MoE should have different compute/memory characteristics
dense_flops
=
dense_metrics
.
get_num_flops
(
ctx
)
moe_flops
=
moe_metrics
.
get_num_flops
(
ctx
)
# 2 routed experts vs 1 dense.
assert
moe_flops
==
dense_flops
*
2
def
test_unembed_metrics_scaling
():
"""Test unembedding metrics scale with vocab size."""
small_vocab_hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
vocab_size
=
32000
,
)
small_vocab_config
=
create_mock_vllm_config
(
small_vocab_hf_config
)
large_vocab_hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
vocab_size
=
64000
,
# Double vocab size
)
large_vocab_config
=
create_mock_vllm_config
(
large_vocab_hf_config
)
small_vocab_metrics
=
UnembedMetrics
.
from_vllm_config
(
small_vocab_config
)
large_vocab_metrics
=
UnembedMetrics
.
from_vllm_config
(
large_vocab_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
True
)
# FLOPS should double when vocab size doubles
small_flops
=
small_vocab_metrics
.
get_num_flops
(
ctx
)
large_flops
=
large_vocab_metrics
.
get_num_flops
(
ctx
)
assert
large_flops
==
2
*
small_flops
def
test_prefill_vs_decode_differences
():
"""Test that prefill and decode have different memory access patterns."""
hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
num_attention_heads
=
16
,
num_key_value_heads
=
16
,
num_hidden_layers
=
1
,
)
config
=
create_mock_vllm_config
(
hf_config
)
metrics
=
AttentionMetrics
.
from_vllm_config
(
config
)
prefill_ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
512
,
context_len
=
512
,
is_prefill
=
True
)
decode_ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
1
,
context_len
=
512
,
is_prefill
=
False
)
prefill_read
=
metrics
.
get_read_bytes
(
prefill_ctx
)
decode_read
=
metrics
.
get_read_bytes
(
decode_ctx
)
assert
prefill_read
!=
decode_read
def
test_model_metrics_aggregation
():
"""Test ModelMetrics correctly aggregates across components."""
hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
num_attention_heads
=
16
,
num_hidden_layers
=
12
,
vocab_size
=
32000
,
intermediate_size
=
8192
,
)
config
=
create_mock_vllm_config
(
hf_config
)
model_metrics
=
ModelMetrics
(
config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
True
)
# Should have metrics for attention, ffn, and unembed
total_flops
=
model_metrics
.
get_num_flops
(
ctx
)
breakdown
=
model_metrics
.
get_num_flops_breakdown
(
ctx
)
# Breakdown should sum to total
assert
total_flops
==
sum
(
breakdown
.
values
())
def
test_moe_expert_activation_proportional_scaling
():
"""Test that routed expert metrics scale proportionally with num_experts_per_tok."""
base_moe_config
=
Qwen3MoeConfig
(
hidden_size
=
2048
,
intermediate_size
=
8192
,
num_hidden_layers
=
12
,
num_experts
=
64
,
num_experts_per_tok
=
1
,
# 1 expert per token
moe_intermediate_size
=
8192
,
n_shared_experts
=
2
,
)
double_experts_config
=
Qwen3MoeConfig
(
hidden_size
=
2048
,
intermediate_size
=
8192
,
num_hidden_layers
=
12
,
num_experts
=
64
,
num_experts_per_tok
=
2
,
# 2 experts per token (double)
moe_intermediate_size
=
8192
,
n_shared_experts
=
2
,
# Same shared experts
)
triple_experts_config
=
Qwen3MoeConfig
(
hidden_size
=
2048
,
intermediate_size
=
8192
,
num_hidden_layers
=
12
,
num_experts
=
64
,
num_experts_per_tok
=
3
,
# 3 experts per token (triple)
moe_intermediate_size
=
8192
,
n_shared_experts
=
2
,
# Same shared experts
)
base_vllm_config
=
create_mock_vllm_config
(
base_moe_config
)
double_vllm_config
=
create_mock_vllm_config
(
double_experts_config
)
triple_vllm_config
=
create_mock_vllm_config
(
triple_experts_config
)
base_metrics
=
FfnMetrics
.
from_vllm_config
(
base_vllm_config
)
double_metrics
=
FfnMetrics
.
from_vllm_config
(
double_vllm_config
)
triple_metrics
=
FfnMetrics
.
from_vllm_config
(
triple_vllm_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
True
)
# Get total metrics - the key insight is that differences should be proportional
base_flops
=
base_metrics
.
get_num_flops
(
ctx
)
double_flops
=
double_metrics
.
get_num_flops
(
ctx
)
triple_flops
=
triple_metrics
.
get_num_flops
(
ctx
)
# The difference between double and base should equal one additional expert
one_expert_diff
=
double_flops
-
base_flops
# The difference between triple and base should equal two additional experts
two_expert_diff
=
triple_flops
-
base_flops
# Proportional scaling: 2 * (1 expert diff) should equal (2 expert diff)
assert
two_expert_diff
==
2
*
one_expert_diff
# Same logic applies to memory operations
base_read
=
base_metrics
.
get_read_bytes
(
ctx
)
double_read
=
double_metrics
.
get_read_bytes
(
ctx
)
triple_read
=
triple_metrics
.
get_read_bytes
(
ctx
)
one_expert_read_diff
=
double_read
-
base_read
two_expert_read_diff
=
triple_read
-
base_read
assert
two_expert_read_diff
==
2
*
one_expert_read_diff
# Same for write bytes
base_write
=
base_metrics
.
get_write_bytes
(
ctx
)
double_write
=
double_metrics
.
get_write_bytes
(
ctx
)
triple_write
=
triple_metrics
.
get_write_bytes
(
ctx
)
one_expert_write_diff
=
double_write
-
base_write
two_expert_write_diff
=
triple_write
-
base_write
assert
two_expert_write_diff
==
2
*
one_expert_write_diff
def
test_quantization_config_parser_fp8
():
"""Test quantization parsers with fp8."""
class
MockQuantConfig
:
def
get_name
(
self
):
return
"fp8"
hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
num_attention_heads
=
16
,
num_hidden_layers
=
1
)
vllm_config
=
create_mock_vllm_config
(
hf_config
,
quant_config
=
MockQuantConfig
())
attn_result
=
AttentionMetrics
.
get_parser
().
parse
(
vllm_config
)
assert
attn_result
.
weight_byte_size
==
1
# fp8
ffn_result
=
FfnMetrics
.
get_parser
().
parse
(
vllm_config
)
assert
ffn_result
.
weight_byte_size
==
1
# fp8
def
test_quantization_config_parser_mxfp4
():
"""Test quantization parsers with mxfp4."""
class
MockQuantConfig
:
def
get_name
(
self
):
return
"mxfp4"
hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
intermediate_size
=
8192
,
num_hidden_layers
=
1
)
vllm_config
=
create_mock_vllm_config
(
hf_config
,
quant_config
=
MockQuantConfig
())
ffn_result
=
FfnMetrics
.
get_parser
().
parse
(
vllm_config
)
assert
ffn_result
.
weight_byte_size
==
0.5
# mxfp4
#### Per-GPU Tests ####
def
test_attention_per_gpu_with_tensor_parallelism
():
"""Test attention metrics with tensor parallelism - per_gpu vs global."""
hf_config
=
Qwen3Config
(
hidden_size
=
4096
,
num_attention_heads
=
32
,
num_key_value_heads
=
8
,
num_hidden_layers
=
24
,
)
# Test with TP=4
vllm_config
=
create_mock_vllm_config
(
hf_config
,
tensor_parallel_size
=
4
)
metrics
=
AttentionMetrics
.
from_vllm_config
(
vllm_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
128
,
context_len
=
1024
,
is_prefill
=
True
)
# Get global and per-gpu metrics
global_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
False
)
per_gpu_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
True
)
# With TP=4, global flops should be 4x per-gpu flops (heads divided by 4)
assert
global_flops
==
4
*
per_gpu_flops
# Same for read/write bytes
global_read
=
metrics
.
get_read_bytes
(
ctx
,
per_gpu
=
False
)
per_gpu_read
=
metrics
.
get_read_bytes
(
ctx
,
per_gpu
=
True
)
# Reads should scale similarly (weight reads are divided by TP)
assert
global_read
>
per_gpu_read
global_write
=
metrics
.
get_write_bytes
(
ctx
,
per_gpu
=
False
)
per_gpu_write
=
metrics
.
get_write_bytes
(
ctx
,
per_gpu
=
True
)
assert
global_write
>
per_gpu_write
def
test_attention_per_gpu_with_pipeline_parallelism
():
"""Test attention metrics with pipeline parallelism - per_gpu vs global."""
hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
num_attention_heads
=
16
,
num_hidden_layers
=
32
,
)
# Test with PP=4
vllm_config
=
create_mock_vllm_config
(
hf_config
,
pipeline_parallel_size
=
4
)
metrics
=
AttentionMetrics
.
from_vllm_config
(
vllm_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
False
)
# Get global and per-gpu metrics
global_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
False
)
per_gpu_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
True
)
# With PP=4, global flops should be 4x per-gpu flops (layers divided by 4)
assert
global_flops
==
4
*
per_gpu_flops
global_read
=
metrics
.
get_read_bytes
(
ctx
,
per_gpu
=
False
)
per_gpu_read
=
metrics
.
get_read_bytes
(
ctx
,
per_gpu
=
True
)
assert
global_read
==
4
*
per_gpu_read
def
test_ffn_per_gpu_with_tensor_parallelism
():
"""Test FFN metrics with tensor parallelism - per_gpu vs global."""
hf_config
=
Qwen3Config
(
hidden_size
=
4096
,
intermediate_size
=
14336
,
num_hidden_layers
=
32
,
)
# Test with DP=2, TP=4 (ffn_tp_size will be 8)
vllm_config
=
create_mock_vllm_config
(
hf_config
,
data_parallel_size
=
2
,
tensor_parallel_size
=
4
,
)
metrics
=
FfnMetrics
.
from_vllm_config
(
vllm_config
)
# ffn_tp_size should be dp_size * tp_size = 8 (when EP not enabled)
assert
metrics
.
ffn_tp_size
==
8
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
128
,
context_len
=
2048
,
is_prefill
=
True
)
# Get global and per-gpu metrics
global_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
False
)
per_gpu_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
True
)
# With ffn_tp_size=8, global should be 8x per-gpu
assert
global_flops
==
8
*
per_gpu_flops
def
test_ffn_per_gpu_with_pipeline_parallelism
():
"""Test FFN metrics with pipeline parallelism - per_gpu vs global."""
hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
intermediate_size
=
8192
,
num_hidden_layers
=
24
,
)
# Test with PP=6
vllm_config
=
create_mock_vllm_config
(
hf_config
,
pipeline_parallel_size
=
6
)
metrics
=
FfnMetrics
.
from_vllm_config
(
vllm_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
True
)
# Get global and per-gpu metrics
global_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
False
)
per_gpu_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
True
)
# With PP=6, global should be 6x per-gpu (layers divided by 6)
assert
global_flops
==
6
*
per_gpu_flops
def
test_moe_per_gpu_with_expert_parallelism
():
"""
Test MoE metrics with expert parallelism - verifies num_activated_experts bug fix.
"""
hf_config
=
Qwen3MoeConfig
(
hidden_size
=
2048
,
intermediate_size
=
8192
,
num_hidden_layers
=
24
,
num_experts
=
64
,
num_experts_per_tok
=
8
,
moe_intermediate_size
=
14336
,
n_shared_experts
=
2
,
)
# Test with DP=2, TP=4, EP enabled (ffn_ep_size will be 8)
vllm_config
=
create_mock_vllm_config
(
hf_config
,
data_parallel_size
=
2
,
tensor_parallel_size
=
4
,
enable_expert_parallel
=
True
,
)
metrics
=
FfnMetrics
.
from_vllm_config
(
vllm_config
)
# When EP enabled, ffn_ep_size = dp_size * tp_size = 8
assert
metrics
.
ffn_ep_size
==
8
assert
metrics
.
ffn_tp_size
==
1
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
True
)
# Get per-gpu metrics
per_gpu_read_breakdown
=
metrics
.
get_read_bytes_breakdown
(
ctx
,
per_gpu
=
True
)
global_read_breakdown
=
metrics
.
get_read_bytes_breakdown
(
ctx
,
per_gpu
=
False
)
# Verify that routed expert weight reads are reasonable
# With per_gpu=True, each GPU has 64/8 = 8 experts
# T=100, E_per_gpu=8/8=1, so T*E=100 expert activations
# num_activated_experts should be min(100, 8) = 8
# Check that weight reads scale appropriately
# Global has all 64 experts, per-gpu has 8 experts
# So weight reads should reflect this difference
if
"routed_up_gate_weights"
in
per_gpu_read_breakdown
:
per_gpu_weight_reads
=
per_gpu_read_breakdown
[
"routed_up_gate_weights"
]
global_weight_reads
=
global_read_breakdown
[
"routed_up_gate_weights"
]
# The ratio should reflect the expert count difference
# This verifies the bug fix works correctly
assert
per_gpu_weight_reads
<
global_weight_reads
# Global should read more experts than per-gpu
# Exact ratio depends on num_activated_experts calculation
ratio
=
global_weight_reads
/
per_gpu_weight_reads
# Should be > 1 since global has more experts to read
assert
ratio
>
1
def
test_moe_per_gpu_expert_activation_accounting
():
"""
Test that MoE correctly accounts for expert activations with small batch sizes.
"""
hf_config
=
Qwen3MoeConfig
(
hidden_size
=
2048
,
intermediate_size
=
8192
,
num_hidden_layers
=
12
,
num_experts
=
64
,
num_experts_per_tok
=
8
,
moe_intermediate_size
=
14336
,
n_shared_experts
=
0
,
# No shared experts for this test
)
# Test with EP=8
vllm_config
=
create_mock_vllm_config
(
hf_config
,
data_parallel_size
=
8
,
enable_expert_parallel
=
True
,
)
metrics
=
FfnMetrics
.
from_vllm_config
(
vllm_config
)
# Small batch: T=10, E_per_gpu=8/8=1
# Each GPU: T*E = 10*1 = 10 activations
# Experts per GPU: 64/8 = 8
# So num_activated_experts should be min(10, 8) = 8
small_ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
10
,
context_len
=
512
,
is_prefill
=
True
)
small_read
=
metrics
.
get_read_bytes_breakdown
(
small_ctx
,
per_gpu
=
True
)
# Large batch: T=1000, E_per_gpu=1
# Each GPU: T*E = 1000*1 = 1000 activations
# Experts per GPU: 8
# So num_activated_experts should be min(1000, 8) = 8 (all experts activated)
large_ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
1000
,
context_len
=
512
,
is_prefill
=
True
)
large_read
=
metrics
.
get_read_bytes_breakdown
(
large_ctx
,
per_gpu
=
True
)
# Weight reads should be similar (both activate all 8 experts per GPU)
# But activation reads should differ (proportional to T*E)
if
"routed_up_gate_weights"
in
small_read
:
small_weight
=
small_read
[
"routed_up_gate_weights"
]
large_weight
=
large_read
[
"routed_up_gate_weights"
]
# Weight reads should be the same (both read all 8 experts)
assert
small_weight
==
large_weight
# But input activation reads should scale with T*E
small_input
=
small_read
[
"routed_up_gate_input"
]
large_input
=
large_read
[
"routed_up_gate_input"
]
assert
large_input
==
100
*
small_input
# 1000/10 = 100x
def
test_unembed_per_gpu_with_tensor_parallelism
():
"""Test unembed metrics with tensor parallelism - per_gpu vs global."""
hf_config
=
Qwen3Config
(
hidden_size
=
4096
,
vocab_size
=
128000
,
)
# Test with TP=8
vllm_config
=
create_mock_vllm_config
(
hf_config
,
tensor_parallel_size
=
8
)
metrics
=
UnembedMetrics
.
from_vllm_config
(
vllm_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
True
)
# Get global and per-gpu metrics
global_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
False
)
per_gpu_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
True
)
# With TP=8, vocab is divided by 8, so global should be 8x per-gpu
assert
global_flops
==
8
*
per_gpu_flops
# For read bytes, weight reads scale with TP but input reads don't (replicated)
global_read_breakdown
=
metrics
.
get_read_bytes_breakdown
(
ctx
,
per_gpu
=
False
)
per_gpu_read_breakdown
=
metrics
.
get_read_bytes_breakdown
(
ctx
,
per_gpu
=
True
)
# Input reads should be the same (replicated across TP ranks)
assert
global_read_breakdown
[
"input"
]
==
per_gpu_read_breakdown
[
"input"
]
# Weight reads should scale 8x (divided by TP)
assert
global_read_breakdown
[
"weight"
]
==
8
*
per_gpu_read_breakdown
[
"weight"
]
def
test_model_metrics_per_gpu_aggregation
():
"""Test ModelMetrics correctly aggregates per_gpu metrics across components."""
hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
num_attention_heads
=
16
,
num_hidden_layers
=
12
,
vocab_size
=
32000
,
intermediate_size
=
8192
,
)
# Test with mixed parallelism: TP=2, PP=2
vllm_config
=
create_mock_vllm_config
(
hf_config
,
tensor_parallel_size
=
2
,
pipeline_parallel_size
=
2
,
)
model_metrics
=
ModelMetrics
(
vllm_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
100
,
context_len
=
512
,
is_prefill
=
True
)
# Get breakdowns for both modes
per_gpu_breakdown
=
model_metrics
.
get_num_flops_breakdown
(
ctx
,
per_gpu
=
True
)
global_breakdown
=
model_metrics
.
get_num_flops_breakdown
(
ctx
,
per_gpu
=
False
)
# Verify breakdown sums match totals
per_gpu_total
=
model_metrics
.
get_num_flops
(
ctx
,
per_gpu
=
True
)
global_total
=
model_metrics
.
get_num_flops
(
ctx
,
per_gpu
=
False
)
assert
per_gpu_total
==
sum
(
per_gpu_breakdown
.
values
())
assert
global_total
==
sum
(
global_breakdown
.
values
())
# Global should be larger than per-gpu due to parallelism
assert
global_total
>
per_gpu_total
# With TP=2 and PP=2, the ratio depends on which parallelism applies to
# which component but we can verify that global is reasonably larger
ratio
=
global_total
/
per_gpu_total
assert
ratio
>
1
# Should be between PP and TP*PP depending on component mix
def
test_attention_per_gpu_heads_not_evenly_divisible
():
"""Test attention with heads not evenly divisible by TP."""
hf_config
=
Qwen3Config
(
hidden_size
=
2048
,
num_attention_heads
=
17
,
# Not divisible by 4
num_key_value_heads
=
5
,
# Not divisible by 4
num_hidden_layers
=
8
,
)
vllm_config
=
create_mock_vllm_config
(
hf_config
,
tensor_parallel_size
=
4
)
metrics
=
AttentionMetrics
.
from_vllm_config
(
vllm_config
)
ctx
=
ExecutionContext
.
from_single_request
(
num_tokens
=
64
,
context_len
=
256
,
is_prefill
=
True
)
# Should not crash and should handle max(1, ...) correctly
per_gpu_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
True
)
global_flops
=
metrics
.
get_num_flops
(
ctx
,
per_gpu
=
False
)
# Both should be positive
assert
per_gpu_flops
>
0
assert
global_flops
>
0
assert
global_flops
>
per_gpu_flops
tests/v1/spec_decode/test_eagle.py
View file @
a810671a
...
...
@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
)
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
DeviceConfig
,
ModelConfig
,
...
...
@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def
_create_proposer
(
method
:
str
,
num_speculative_tokens
:
int
,
attention_backend
:
str
|
None
=
None
,
speculative_token_tree
:
list
[
tuple
[
int
,
...]]
|
None
=
None
,
)
->
EagleProposer
:
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
...
...
@@ -70,6 +72,7 @@ def _create_proposer(
max_model_len
=
model_config
.
max_model_len
,
is_encoder_decoder
=
model_config
.
is_encoder_decoder
,
),
attention_config
=
AttentionConfig
(
backend
=
attention_backend
),
)
return
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
current_platform
.
device_type
)
...
...
@@ -331,8 +334,6 @@ def test_load_model(
use_distinct_lm_head
,
monkeypatch
,
):
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
"TRITON_ATTN does not support "
...
...
@@ -396,7 +397,9 @@ def test_load_model(
assert
not
isinstance
(
target_model
,
SupportsMultiModal
)
# Create proposer using the helper function
proposer
=
_create_proposer
(
method
,
num_speculative_tokens
=
8
)
proposer
=
_create_proposer
(
method
,
num_speculative_tokens
=
8
,
attention_backend
=
attn_backend
)
# Call the method under test
proposer
.
load_model
(
target_model
)
...
...
@@ -422,8 +425,6 @@ def test_load_model(
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
@
pytest
.
mark
.
parametrize
(
"num_speculative_tokens"
,
[
1
,
3
,
8
])
def
test_propose
(
method
,
attn_backend
,
num_speculative_tokens
,
monkeypatch
):
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
"TRITON_ATTN does not support "
...
...
@@ -451,7 +452,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
seq_lens
=
[
seq_len_1
,
seq_len_2
]
# Create proposer first so we can use its actual hidden_size
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
)
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
,
attention_backend
=
attn_backend
)
# Get the hidden_size from the proposer to ensure consistency
hidden_size
=
proposer
.
hidden_size
...
...
@@ -624,7 +627,9 @@ def test_propose_tree(spec_token_tree):
# Create proposer first so we can use its actual hidden_size.
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
,
speculative_token_tree
=
spec_token_tree
"eagle"
,
num_speculative_tokens
,
speculative_token_tree
=
spec_token_tree
,
)
# Get the hidden_size from the proposer to ensure consistency.
hidden_size
=
proposer
.
hidden_size
...
...
tests/v1/spec_decode/test_max_len.py
View file @
a810671a
...
...
@@ -38,53 +38,48 @@ def test_ngram_max_len(num_speculative_tokens: int):
def
test_eagle_max_len
(
monkeypatch
:
pytest
.
MonkeyPatch
,
num_speculative_tokens
:
int
,
attn_backend
:
str
):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if
attn_backend
==
"ROCM_AITER_FA"
and
current_platform
.
is_rocm
():
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
if
attn_backend
==
"ROCM_AITER_FA"
and
current_platform
.
is_rocm
():
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
,
enforce_eager
=
True
,
# For faster initialization.
speculative_config
=
{
"method"
:
"eagle"
,
"model"
:
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
"num_speculative_tokens"
:
num_speculative_tokens
,
"max_model_len"
:
80
,
},
max_model_len
=
200
,
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
,
enforce_eager
=
True
,
# For faster initialization.
speculative_config
=
{
"method"
:
"eagle"
,
"model"
:
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
"num_speculative_tokens"
:
num_speculative_tokens
,
"max_model_len"
:
80
,
},
max_model_len
=
200
,
attention_config
=
{
"backend"
:
attn_backend
},
)
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
ignore_eos
=
True
)
outputs
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
for
o
in
outputs
:
assert
o
.
outputs
[
0
].
finish_reason
==
"length"
,
(
"This test is only meaningful if the output is truncated due to max length"
)
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
ignore_eos
=
True
)
outputs
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
for
o
in
outputs
:
assert
o
.
outputs
[
0
].
finish_reason
==
"length"
,
(
"This test is only meaningful if the output "
"is truncated due to max length"
)
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
structured_outputs
=
StructuredOutputsParams
(
regex
=
"^"
+
"a b c d e "
*
15
+
"$"
),
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
structured_outputs
=
StructuredOutputsParams
(
regex
=
"^"
+
"a b c d e "
*
15
+
"$"
),
)
output
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
for
o
in
output
:
assert
o
.
prompt_token_ids
is
not
None
assert
(
len
(
o
.
prompt_token_ids
)
<
80
<
len
(
o
.
prompt_token_ids
)
+
len
(
o
.
outputs
[
0
].
token_ids
)
<=
200
),
(
"This test is only meaningful if the output "
"is longer than the eagle max length"
)
output
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
for
o
in
output
:
assert
o
.
prompt_token_ids
is
not
None
assert
(
len
(
o
.
prompt_token_ids
)
<
80
<
len
(
o
.
prompt_token_ids
)
+
len
(
o
.
outputs
[
0
].
token_ids
)
<=
200
),
(
"This test is only meaningful if the output "
"is longer than the eagle max length"
)
assert
o
.
outputs
[
0
].
text
==
"a b c d e "
*
15
assert
o
.
outputs
[
0
].
text
==
"a b c d e "
*
15
tests/v1/tpu/test_mha_attn.py
View file @
a810671a
...
...
@@ -3,7 +3,7 @@
"""
Test:
* Tests for M
ultiHead
Attention layer
* Tests for M
MEncoder
Attention layer
"""
import
pytest
...
...
@@ -12,7 +12,7 @@ import torch_xla
import
torch_xla.core
import
torch_xla.core.xla_model
from
vllm.attention.layer
import
M
ultiHead
Attention
from
vllm.attention.layer
s.mm_encoder_attention
import
M
MEncoder
Attention
from
vllm.attention.selector
import
_cached_get_attn_backend
from
vllm.platforms
import
current_platform
...
...
@@ -69,7 +69,7 @@ def test_mha_attn_forward(
k
=
torch
.
randn
(
batch_size
,
seq_len
,
num_kv_heads
*
head_size
,
device
=
device
)
v
=
torch
.
randn
(
batch_size
,
seq_len
,
num_kv_heads
*
head_size
,
device
=
device
)
scale
=
1.0
/
head_size
**
0.5
attn
=
M
ultiHead
Attention
(
attn
=
M
MEncoder
Attention
(
num_heads
,
head_size
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
)
output
=
attn
(
q
,
k
,
v
)
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
a810671a
...
...
@@ -1110,3 +1110,87 @@ def test_hybrid_cache_integration(model_runner, dist_init):
runner
.
_update_states
(
scheduler_output
)
assert
_is_req_scheduled
(
runner
,
req_id
)
assert
_is_req_state_block_table_match
(
runner
,
req_id
)
def
test_is_uniform_decode
()
->
None
:
# Normal
assert
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
2
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
15
,
)
# Spec decoding
assert
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
5
,
uniform_decode_query_len
=
5
,
num_tokens
=
30
,
num_reqs
=
6
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
5
,
uniform_decode_query_len
=
4
,
num_tokens
=
30
,
num_reqs
=
6
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
5
,
uniform_decode_query_len
=
5
,
num_tokens
=
30
,
num_reqs
=
7
,
)
# Force uniform decode
assert
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
force_uniform_decode
=
True
,
)
assert
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
2
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
force_uniform_decode
=
True
,
)
assert
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
15
,
force_uniform_decode
=
True
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
force_uniform_decode
=
False
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
2
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
force_uniform_decode
=
False
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
15
,
force_uniform_decode
=
False
,
)
vllm/_aiter_ops.py
View file @
a810671a
...
...
@@ -24,14 +24,13 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND
=
is_aiter_found
()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if
IS_AITER_FOUND
:
from
aiter
import
dtypes
AITER_FP8_DTYPE
=
dtypes
.
fp8
def
is_aiter_found_and_supported
()
->
bool
:
if
current_platform
.
is_rocm
()
and
IS_AITER_FOUND
:
from
vllm.platforms.rocm
import
on_gfx9
return
on_gfx9
()
return
False
def
if_aiter_supported
(
func
:
Callable
)
->
Callable
:
...
...
@@ -43,17 +42,24 @@ def if_aiter_supported(func: Callable) -> Callable:
def
wrapper
(
*
args
,
**
kwargs
):
# checks the platform, device arch and aiter library existence.
if
current_platform
.
is_rocm
()
and
IS_AITER_FOUND
:
from
vllm.platforms.rocm
import
on_gfx9
if
on_gfx9
():
return
func
(
*
args
,
**
kwargs
)
if
is_aiter_found_and_supported
():
return
func
(
*
args
,
**
kwargs
)
return
None
return
wrapper
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if
is_aiter_found_and_supported
():
from
aiter
import
dtypes
AITER_FP8_DTYPE
=
dtypes
.
fp8
def
_rocm_aiter_fused_moe_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
...
@@ -642,48 +648,130 @@ _OPS_REGISTERED = False
class
rocm_aiter_ops
:
"""ROCm AITER operations wrapper for AMD GPU acceleration in vLLM.
This class centralizes the import and registration of AITER ops,
and provides a unified interface for checking if AITER is enabled.
Operations are only available on supported gfx9
architectures when aiter is installed.
The class uses environment variables to control which features are enabled,
allowing fine-grained control over which AITER optimizations are used.
Environment Variables:
VLLM_ROCM_USE_AITER: Main toggle for all AITER operations.
VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops.
VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations.
VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops.
VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops.
VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen.
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention.
VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply.
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM.
VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings.
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion.
VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM.
Note:
The environment variables are assigned when the module is imported,
so you can't change the environment variables after the module is imported.
This is done out of performance consideration. Accessing environment variables
is expensive as described in issue https://github.com/vllm-project/vllm/issues/17067
so we don't want to do it repeatedly, especially in the hot path (the forward pass).
You can call the refresh_env_variables() function to reload the env variables
after monkey patching the env variables in the unit test.
Check Functions:
All check functions (is_*_enabled) are decorated with @if_aiter_supported,
which verifies: (1) platform is ROCm, (2) device arch is gfx9, and
(3) aiter library is installed. The check function then also verifies
the corresponding environment variable is enabled.
i.e. ___
is_enabled() == current_platform.is_rocm() and | checked by
current_platform.is_on_gfx9() and | @if_aiter_supported
IS_AITER_FOUND and _______________|
cls._AITER_ENABLED -----> Check by the logic in `is_enabled()`
Example:
from vllm._aiter_ops import rocm_aiter_ops
# Check if aiter is enabled before using operations
if rocm_aiter_ops.is_enabled():
result = rocm_aiter_ops.rms_norm(x, weight, epsilon)
Operations:
- RMS normalization: rms_norm, rms_norm2d_with_add
- GEMM operations: gemm_a8w8, gemm_a8w8_blockscale
- Fused MoE: fused_moe, asm_moe_tkw1
- Routing: topk_softmax, biased_grouped_topk, grouped_topk
- MLA decode: mla_decode_fwd
- Quantization: per_tensor_quant, per_token_quant, group_fp8_quant
- Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale
"""
# Check if the env variable is set
_AITER_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER
_LINEAR_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_LINEAR
_RMSNORM_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_RMSNORM
_FMOE_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_MOE
_MLA_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_MLA
_PG_ATTN_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_PAGED_ATTN
_MHA_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_MHA
_TRITON_UNIFIED_ATTN_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
# TODO: Consolidate under _LINEAR_ENABLED
_FP8BMM_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_FP8BMM
# TODO: Consolidate under _LINEAR_ENABLED
_FP4_GEMM_DYNAMIC_QUANT_ASM
=
envs
.
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
_TRITON_ROTARY_EMBED
=
envs
.
VLLM_ROCM_USE_AITER_TRITON_ROPE
_MOE_SHARED_EXPERTS_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
# TODO: Consolidate under _LINEAR_ENABLED
_TRITON_UNQUANT_GEMM
=
envs
.
VLLM_ROCM_USE_AITER_TRITON_GEMM
@
classmethod
def
refresh_env_variables
(
cls
):
"""
Since the environment variables are assigned when the module is imported,
This is a helper function to reload all the env variables from
the environment variables.
for example, after monkey patching the env variables in the unit test,
you can call this function to reload the env variables.
"""
cls
.
_AITER_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER
cls
.
_LINEAR_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_LINEAR
cls
.
_RMSNORM_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_RMSNORM
cls
.
_FMOE_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_MOE
cls
.
_MLA_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_MLA
cls
.
_MHA_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_MHA
cls
.
_TRITON_UNIFIED_ATTN_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
cls
.
_FP8BMM_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_FP8BMM
cls
.
_FP4_GEMM_DYNAMIC_QUANT_ASM
=
envs
.
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
cls
.
_TRITON_ROTARY_EMBED
=
envs
.
VLLM_ROCM_USE_AITER_TRITON_ROPE
cls
.
_MOE_SHARED_EXPERTS_ENABLED
=
envs
.
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
cls
.
_TRITON_UNQUANT_GEMM
=
envs
.
VLLM_ROCM_USE_AITER_TRITON_GEMM
@
classmethod
@
if_aiter_supported
def
is_enabled
(
cls
)
->
bool
:
"""Verifies device specs and availability of aiter main env variable."""
return
cls
.
_AITER_ENABLED
@
classmethod
@
if_aiter_supported
def
is_linear_enabled
(
cls
)
->
bool
:
""" "Verifies device specs and availability of env variable."""
return
cls
.
_AITER_ENABLED
and
cls
.
_LINEAR_ENABLED
@
classmethod
@
if_aiter_supported
def
is_linear_fp8_enaled
(
cls
)
->
bool
:
""" "Verifies device specs and availability of env variable."""
return
cls
.
is_linear_enabled
()
@
classmethod
@
if_aiter_supported
def
is_rmsnorm_enabled
(
cls
)
->
bool
:
""" "Verifies device specs and availability of env variable."""
return
cls
.
_AITER_ENABLED
and
cls
.
_RMSNORM_ENABLED
@
classmethod
@
if_aiter_supported
def
is_fused_moe_enabled
(
cls
)
->
bool
:
""" "Verifies device specs and availability of env variable."""
return
cls
.
_AITER_ENABLED
and
cls
.
_FMOE_ENABLED
@
classmethod
...
...
@@ -694,25 +782,16 @@ class rocm_aiter_ops:
@
classmethod
@
if_aiter_supported
def
is_mla_enabled
(
cls
)
->
bool
:
""" "Verifies device specs and availability of env variable."""
return
cls
.
_AITER_ENABLED
and
cls
.
_MLA_ENABLED
@
classmethod
@
if_aiter_supported
def
is_mha_enabled
(
cls
)
->
bool
:
""" "Verifies device specs and availability of env variable."""
return
cls
.
_AITER_ENABLED
and
cls
.
_MHA_ENABLED
@
classmethod
@
if_aiter_supported
def
is_pa_attn_enabled
(
cls
)
->
bool
:
""" "Verifies device specs and availability of env variable."""
return
cls
.
_AITER_ENABLED
and
cls
.
_PG_ATTN_ENABLED
@
classmethod
@
if_aiter_supported
def
is_triton_unified_attn_enabled
(
cls
)
->
bool
:
""" "Verifies device specs and availability of env variable."""
return
cls
.
_AITER_ENABLED
and
cls
.
_TRITON_UNIFIED_ATTN_ENABLED
@
classmethod
...
...
vllm/_custom_ops.py
View file @
a810671a
...
...
@@ -2933,6 +2933,42 @@ def cpu_gemm_wna16(
return
output
def
cpu_prepack_moe_weight
(
weight
:
torch
.
Tensor
,
isa
:
str
,
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
weight
)
torch
.
ops
.
_C
.
prepack_moe_weight
(
weight
,
output
,
isa
)
return
output
def
cpu_fused_moe
(
input
:
torch
.
Tensor
,
w13
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w13_bias
:
torch
.
Tensor
|
None
,
w2_bias
:
torch
.
Tensor
|
None
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
act
:
str
,
isa
:
str
,
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
input
)
torch
.
ops
.
_C
.
cpu_fused_moe
(
output
,
input
,
w13
,
w2
,
w13_bias
,
w2_bias
,
topk_weights
,
topk_ids
,
act
,
isa
,
)
return
output
if
hasattr
(
torch
.
ops
.
_qutlass_C
,
"matmul_mxf4_bf16_tn"
):
@
register_fake
(
"_qutlass_C::matmul_mxf4_bf16_tn"
)
...
...
vllm/attention/backends/registry.py
View file @
a810671a
...
...
@@ -201,8 +201,8 @@ _MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}
def
register_backend
(
backend
:
AttentionBackendEnum
|
MambaAttentionBackendEnum
,
is_mamba
:
bool
=
False
,
class_path
:
str
|
None
=
None
,
is_mamba
:
bool
=
False
,
)
->
Callable
[[
type
],
type
]:
"""Register or override a backend implementation.
...
...
vllm/attention/layer.py
View file @
a810671a
...
...
@@ -2,12 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
import
functools
from
typing
import
cast
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
...
...
@@ -16,13 +14,10 @@ from vllm.attention.backends.abstract import (
MLAAttentionImpl
,
)
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.layers.mm_encoder_attention
import
maybe_get_vit_flash_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.attention.utils.kv_sharing_utils
import
validate_kv_sharing_target
from
vllm.attention.utils.kv_transfer_utils
import
maybe_transfer_kv_layer
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.config.multimodal
import
MultiModalConfig
from
vllm.config.vllm
import
VllmConfig
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
...
...
@@ -36,7 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.models.vision
import
get_vit_attn_backend
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
(
direct_register_custom_op
,
...
...
@@ -412,132 +406,6 @@ class Attention(nn.Module, AttentionLayerBase):
)
class
MultiHeadAttention
(
nn
.
Module
):
"""Multi-headed attention without any cache, used for ViT."""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
|
None
=
None
,
# This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix
:
str
=
""
,
multimodal_config
:
MultiModalConfig
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
scale
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
layer_name
=
prefix
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
,
(
f
"num_heads (
{
self
.
num_heads
}
) is not "
f
"divisible by num_kv_heads (
{
self
.
num_kv_heads
}
)"
)
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype
=
torch
.
get_default_dtype
()
# Determine the attention backend
attn_backend_override
=
None
if
multimodal_config
is
not
None
:
attn_backend_override
=
multimodal_config
.
mm_encoder_attn_backend
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_size
,
dtype
=
dtype
,
attn_backend_override
=
attn_backend_override
,
)
self
.
_flash_attn_varlen_func
=
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
)
self
.
is_flash_attn_backend
=
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}
self
.
fa_version
=
None
if
(
self
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
and
current_platform
.
is_cuda
()
):
self
.
fa_version
=
get_flash_attn_version
()
assert
self
.
_flash_attn_varlen_func
is
not
None
self
.
_flash_attn_varlen_func
=
functools
.
partial
(
self
.
_flash_attn_varlen_func
,
fa_version
=
self
.
fa_version
)
logger
.
info_once
(
f
"Using
{
self
.
attn_backend
}
for MultiHeadAttention in multimodal encoder."
)
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz
,
q_len
=
query
.
size
()[:
2
]
kv_len
=
key
.
size
(
1
)
query
=
query
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
(
num_repeat
:
=
self
.
num_queries_per_kv
)
>
1
:
# Handle MQA and GQA
key
=
torch
.
repeat_interleave
(
key
,
num_repeat
,
dim
=
2
)
value
=
torch
.
repeat_interleave
(
value
,
num_repeat
,
dim
=
2
)
if
self
.
is_flash_attn_backend
:
assert
self
.
_flash_attn_varlen_func
is
not
None
cu_seqlens_q
=
torch
.
arange
(
0
,
(
bsz
+
1
)
*
q_len
,
step
=
q_len
,
dtype
=
torch
.
int32
,
device
=
query
.
device
)
cu_seqlens_k
=
torch
.
arange
(
0
,
(
bsz
+
1
)
*
kv_len
,
step
=
kv_len
,
dtype
=
torch
.
int32
,
device
=
key
.
device
)
out
=
self
.
_flash_attn_varlen_func
(
query
.
flatten
(
0
,
1
),
key
.
flatten
(
0
,
1
),
value
.
flatten
(
0
,
1
),
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
q_len
,
max_seqlen_k
=
kv_len
,
softmax_scale
=
self
.
scale
,
)
elif
self
.
attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
:
query
,
key
,
value
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query
,
key
,
value
))
out
=
F
.
scaled_dot_product_attention
(
query
,
key
,
value
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
elif
self
.
attn_backend
==
AttentionBackendEnum
.
PALLAS
:
query
,
key
,
value
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query
,
key
,
value
))
from
torch_xla.experimental.custom_kernel
import
flash_attention
out
=
flash_attention
(
query
,
key
,
value
,
sm_scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
else
:
# ViT attention hasn't supported this backend yet
raise
NotImplementedError
(
f
"ViT attention hasn't supported
{
self
.
attn_backend
}
backend yet."
)
return
out
.
reshape
(
bsz
,
q_len
,
-
1
)
class
MLAAttention
(
nn
.
Module
,
AttentionLayerBase
):
"""Multi-Head Latent Attention layer.
...
...
vllm/attention/layers/chunked_local_attention.py
View file @
a810671a
...
...
@@ -4,7 +4,7 @@ import functools
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
...
...
@@ -51,11 +51,19 @@ def create_chunked_local_attention_backend(
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
AttentionMetadata
:
c
ommon_attn_metadata
=
make_local_attention_virtual_batches
(
):
c
m
,
make_virtual_batches_block_table
=
make_local_attention_virtual_batches
(
attention_chunk_size
,
common_attn_metadata
,
block_size
)
return
super
().
build
(
common_prefix_len
,
common_attn_metadata
,
fast_build
)
metadata
=
super
().
build
(
common_prefix_len
,
cm
,
fast_build
)
metadata
.
make_virtual_batches_block_table
=
make_virtual_batches_block_table
return
metadata
def
update_block_table
(
self
,
metadata
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
):
blk_table
=
metadata
.
make_virtual_batches_block_table
(
blk_table
)
return
super
().
update_block_table
(
metadata
,
blk_table
,
slot_mapping
)
attn_backend
=
subclass_attention_backend
(
name_prefix
=
prefix
,
...
...
vllm/attention/layers/mm_encoder_attention.py
View file @
a810671a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
torch
...
...
@@ -10,6 +9,7 @@ from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper
,
vit_torch_sdpa_wrapper
,
)
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.config
import
MultiModalConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
...
...
@@ -18,27 +18,6 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
logger
=
init_logger
(
__name__
)
def
maybe_get_vit_flash_attn_backend
(
attn_backend
:
AttentionBackendEnum
|
None
,
)
->
Callable
|
None
:
# At this point,
# we already have the attn_backend,
# overriding logic is done in the platform-specific implementation.
# so we don't need to override backend here.
# Just return the attn_backend and flash_attn_varlen_func.
if
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
:
from
vllm.attention.utils.fa_utils
import
flash_attn_varlen_func
elif
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
else
:
flash_attn_varlen_func
=
None
# if attn_backend is TORCH_SDPA,
# it will reach here and the flash_attn_varlen_func will be None.
return
flash_attn_varlen_func
@
CustomOp
.
register
(
"mm_encoder_attn"
)
class
MMEncoderAttention
(
CustomOp
):
"""Multi-headed attention without any cache, used for multimodal encoder."""
...
...
@@ -97,8 +76,8 @@ class MMEncoderAttention(CustomOp):
AttentionBackendEnum
.
ROCM_AITER_FA
,
}
self
.
flash_attn_varlen_func
=
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
self
.
_fa_version
=
(
get_flash_attn_version
()
if
self
.
is_flash_
attn_backend
else
None
)
logger
.
info_once
(
f
"Using
{
self
.
attn_backend
}
for MMEncoderAttention."
)
...
...
@@ -107,7 +86,7 @@ class MMEncoderAttention(CustomOp):
def
enabled
(
cls
)
->
bool
:
return
True
def
reshape_qkv_to_4d
(
def
maybe_
reshape_qkv_to_4d
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
...
@@ -131,30 +110,6 @@ class MMEncoderAttention(CustomOp):
return
query
,
key
,
value
def
reshape_qkv_to_3d
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
bsz
:
int
,
q_len
:
int
,
kv_len
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Reshape query, key, value to 3D tensors:
(batch_size * seq_len, num_heads, head_size)
"""
query
=
query
.
view
(
bsz
*
q_len
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
bsz
*
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
bsz
*
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
(
num_repeat
:
=
self
.
num_queries_per_kv
)
>
1
:
# Handle MQA and GQA
key
=
torch
.
repeat_interleave
(
key
,
num_repeat
,
dim
=
1
)
value
=
torch
.
repeat_interleave
(
value
,
num_repeat
,
dim
=
1
)
return
query
,
key
,
value
def
_forward_sdpa
(
self
,
query
:
torch
.
Tensor
,
...
...
@@ -162,13 +117,15 @@ class MMEncoderAttention(CustomOp):
value
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
# TODO(Isotr0py): Migrate MultiHeadAttention
assert
cu_seqlens
is
not
None
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz
,
q_len
=
query
.
size
()[:
2
]
kv_len
=
key
.
size
(
1
)
is_reshaped
=
query
.
dim
()
!=
4
query
,
key
,
value
=
self
.
reshape_qkv_to_4d
(
query
,
key
,
value
=
self
.
maybe_
reshape_qkv_to_4d
(
query
,
key
,
value
,
bsz
,
q_len
,
kv_len
)
...
...
@@ -178,6 +135,8 @@ class MMEncoderAttention(CustomOp):
v
=
value
,
cu_seqlens
=
cu_seqlens
,
)
if
is_reshaped
:
output
=
output
.
view
(
bsz
,
q_len
,
-
1
)
return
output
def
_forward_fa
(
...
...
@@ -188,13 +147,21 @@ class MMEncoderAttention(CustomOp):
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
# Only used for Flash Attention
)
->
torch
.
Tensor
:
assert
self
.
flash_attn_varlen_func
is
not
None
,
(
"Flash attention function is not set."
)
# # TODO(Isotr0py): Migrate MultiHeadAttention
assert
cu_seqlens
is
not
None
and
max_seqlen
is
not
None
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
assert
(
cu_seqlens
is
not
None
and
max_seqlen
is
not
None
)
or
(
cu_seqlens
is
None
and
max_seqlen
is
None
),
"cu_seqlens and max_seqlen should be both set or both None."
bsz
=
query
.
shape
[
0
]
bsz
,
q_len
=
query
.
size
()[:
2
]
kv_len
=
key
.
size
(
1
)
is_reshaped
=
query
.
dim
()
!=
4
query
,
key
,
value
=
self
.
maybe_reshape_qkv_to_4d
(
query
,
key
,
value
,
bsz
,
q_len
,
kv_len
)
output
=
vit_flash_attn_wrapper
(
q
=
query
,
...
...
@@ -204,7 +171,10 @@ class MMEncoderAttention(CustomOp):
max_seqlen
=
max_seqlen
,
batch_size
=
bsz
,
is_rocm_aiter
=
(
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
),
fa_version
=
self
.
_fa_version
,
)
if
is_reshaped
:
output
=
output
.
view
(
bsz
,
q_len
,
-
1
)
return
output
def
forward_native
(
...
...
vllm/attention/ops/vit_attn_wrappers.py
View file @
a810671a
...
...
@@ -24,15 +24,28 @@ def flash_attn_maxseqlen_wrapper(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
batch_size
:
int
,
is_rocm_aiter
:
bool
,
fa_version
:
int
|
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
kwargs
=
{}
if
is_rocm_aiter
:
from
aiter
import
flash_attn_varlen_func
else
:
from
vllm.attention.utils.fa_utils
import
flash_attn_varlen_func
if
not
current_platform
.
is_rocm
()
and
fa_version
is
not
None
:
kwargs
[
"fa_version"
]
=
fa_version
q_len
=
q
.
size
(
1
)
if
cu_seqlens
is
None
:
cu_seqlens
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
q_len
,
step
=
q_len
,
dtype
=
torch
.
int32
,
device
=
q
.
device
)
max_seqlen
=
q_len
if
max_seqlen
is
None
else
max_seqlen
.
item
()
q
,
k
,
v
=
(
einops
.
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
output
=
flash_attn_varlen_func
(
q
,
...
...
@@ -40,10 +53,11 @@ def flash_attn_maxseqlen_wrapper(
v
,
cu_seqlens_q
=
cu_seqlens
,
cu_seqlens_k
=
cu_seqlens
,
max_seqlen_q
=
max_seqlen
.
item
()
,
max_seqlen_k
=
max_seqlen
.
item
()
,
max_seqlen_q
=
max_seqlen
,
max_seqlen_k
=
max_seqlen
,
dropout_p
=
0.0
,
causal
=
False
,
**
kwargs
,
)
context_layer
=
einops
.
rearrange
(
output
,
"(b s) h d -> b s h d"
,
b
=
batch_size
)
return
context_layer
...
...
@@ -57,6 +71,7 @@ def flash_attn_maxseqlen_wrapper_fake(
max_seqlen
:
torch
.
Tensor
,
batch_size
:
int
,
is_rocm_aiter
:
bool
,
fa_version
:
int
|
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
)
...
...
@@ -72,23 +87,42 @@ def vit_flash_attn_wrapper(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
batch_size
:
int
,
is_rocm_aiter
:
bool
,
fa_version
:
int
|
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
flash_attn_maxseqlen_wrapper
(
q
,
k
,
v
,
cu_seqlens
,
max_seqlen
,
batch_size
,
is_rocm_aiter
q
,
k
,
v
,
batch_size
,
is_rocm_aiter
,
fa_version
,
cu_seqlens
,
max_seqlen
,
)
def
apply_sdpa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Input shape:
(batch_size x seq_len x num_heads x head_size)
"""
q
,
k
,
v
=
(
einops
.
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
])
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
dropout_p
=
0.0
)
output
=
einops
.
rearrange
(
output
,
"b h s d -> b s h d "
)
return
output
# TODO: Once we have a torch 2.10, we can use tensor slices
# so we won't need to wrap this in custom ops
def
torch_sdpa_wrapper
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
...
...
@@ -97,6 +131,9 @@ def torch_sdpa_wrapper(
k
=
k
.
contiguous
()
v
=
v
.
contiguous
()
if
cu_seqlens
is
None
:
return
apply_sdpa
(
q
,
k
,
v
)
outputs
=
[]
lens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
...
...
@@ -104,11 +141,7 @@ def torch_sdpa_wrapper(
k_chunks
=
torch
.
split
(
k
,
lens
,
dim
=
1
)
v_chunks
=
torch
.
split
(
v
,
lens
,
dim
=
1
)
for
q_i
,
k_i
,
v_i
in
zip
(
q_chunks
,
k_chunks
,
v_chunks
):
q_i
,
k_i
,
v_i
=
(
einops
.
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q_i
,
k_i
,
v_i
]
)
output_i
=
F
.
scaled_dot_product_attention
(
q_i
,
k_i
,
v_i
,
dropout_p
=
0.0
)
output_i
=
einops
.
rearrange
(
output_i
,
"b h s d -> b s h d "
)
output_i
=
apply_sdpa
(
q_i
,
k_i
,
v_i
)
outputs
.
append
(
output_i
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
return
context_layer
...
...
@@ -134,6 +167,6 @@ def vit_torch_sdpa_wrapper(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
torch_sdpa_wrapper
(
q
,
k
,
v
,
cu_seqlens
)
vllm/benchmarks/latency.py
View file @
a810671a
...
...
@@ -79,10 +79,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
def
main
(
args
:
argparse
.
Namespace
):
engine_args
=
EngineArgs
.
from_cli_args
(
args
)
if
args
.
profile
and
not
engine_args
.
profiler_config
.
profiler
==
"torch"
:
raise
ValueError
(
"The torch profiler is not enabled. Please provide profiler_config."
)
# Lazy import to avoid importing LLM when the bench command is not selected.
from
vllm
import
LLM
,
SamplingParams
...
...
@@ -125,8 +121,8 @@ def main(args: argparse.Namespace):
),
)
def
run_to_completion
(
profile
_dir
:
str
|
None
=
Non
e
):
if
profile
_dir
:
def
run_to_completion
(
do_
profile
:
bool
=
Fals
e
):
if
do_
profile
:
llm
.
start_profile
()
llm_generate
()
llm
.
stop_profile
()
...
...
@@ -139,18 +135,24 @@ def main(args: argparse.Namespace):
print
(
"Warming up..."
)
for
_
in
tqdm
(
range
(
args
.
num_iters_warmup
),
desc
=
"Warmup iterations"
):
run_to_completion
(
profile
_dir
=
Non
e
)
run_to_completion
(
do_
profile
=
Fals
e
)
if
args
.
profile
:
profile_dir
=
engine_args
.
profiler_config
.
torch_profiler_dir
print
(
f
"Profiling (results will be saved to '
{
profile_dir
}
')..."
)
run_to_completion
(
profile_dir
=
profile_dir
)
profiler_config
=
engine_args
.
profiler_config
if
profiler_config
.
profiler
==
"torch"
:
print
(
"Profiling with torch profiler (results will be saved to"
f
"
{
profiler_config
.
torch_profiler_dir
}
)..."
)
elif
profiler_config
.
profiler
==
"cuda"
:
print
(
"Profiling with cuda profiler ..."
)
run_to_completion
(
do_profile
=
True
)
return
# Benchmark.
latencies
=
[]
for
_
in
tqdm
(
range
(
args
.
num_iters
),
desc
=
"
Profiling
iterations"
):
latencies
.
append
(
run_to_completion
(
profile
_dir
=
Non
e
))
for
_
in
tqdm
(
range
(
args
.
num_iters
),
desc
=
"
Bench
iterations"
):
latencies
.
append
(
run_to_completion
(
do_
profile
=
Fals
e
))
latencies
=
np
.
array
(
latencies
)
percentages
=
[
10
,
25
,
50
,
75
,
90
,
99
]
percentiles
=
np
.
percentile
(
latencies
,
percentages
)
...
...
vllm/benchmarks/serve.py
View file @
a810671a
...
...
@@ -10,8 +10,10 @@ On the client side, run:
vllm bench serve \
--backend <backend or endpoint type. Default 'openai'> \
--label <benchmark result label. Default using backend> \
--model <your_model> \
--model <your_model
. Optional, defaults to first model from server
> \
--dataset-name <dataset_name. Default 'random'> \
--input-len <general input length. Optional, maps to dataset-specific args> \
--output-len <general output length. Optional, maps to dataset-specific args> \
--request-rate <request_rate. Default inf> \
--num-prompts <num_prompts. Default 1000>
"""
...
...
@@ -57,6 +59,33 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
)
async
def
get_first_model_from_server
(
base_url
:
str
,
headers
:
dict
|
None
=
None
)
->
str
:
"""Fetch the first model from the server's /v1/models endpoint."""
models_url
=
f
"
{
base_url
}
/v1/models"
async
with
aiohttp
.
ClientSession
()
as
session
:
try
:
async
with
session
.
get
(
models_url
,
headers
=
headers
)
as
response
:
response
.
raise_for_status
()
data
=
await
response
.
json
()
if
"data"
in
data
and
len
(
data
[
"data"
])
>
0
:
return
data
[
"data"
][
0
][
"id"
]
else
:
raise
ValueError
(
f
"No models found on the server at
{
base_url
}
. "
"Make sure the server is running and has models loaded."
)
except
(
aiohttp
.
ClientError
,
json
.
JSONDecodeError
)
as
e
:
raise
RuntimeError
(
f
"Failed to fetch models from server at
{
models_url
}
. "
"Check that:
\n
"
"1. The server is running
\n
"
"2. The server URL is correct
\n
"
f
"Error:
{
e
}
"
)
from
e
class
TaskType
(
Enum
):
GENERATION
=
"generation"
POOLING
=
"pooling"
...
...
@@ -1025,8 +1054,26 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser
.
add_argument
(
"--model"
,
type
=
str
,
required
=
True
,
help
=
"Name of the model."
,
required
=
False
,
default
=
None
,
help
=
"Name of the model. If not specified, will fetch the first model "
"from the server's /v1/models endpoint."
,
)
parser
.
add_argument
(
"--input-len"
,
type
=
int
,
default
=
None
,
help
=
"General input length for datasets. Maps to dataset-specific "
"input length arguments (e.g., --random-input-len, --sonnet-input-len). "
"If not specified, uses dataset defaults."
,
)
parser
.
add_argument
(
"--output-len"
,
type
=
int
,
default
=
None
,
help
=
"General output length for datasets. Maps to dataset-specific "
"output length arguments (e.g., --random-output-len, --sonnet-output-len). "
"If not specified, uses dataset defaults."
,
)
parser
.
add_argument
(
"--tokenizer"
,
...
...
@@ -1332,10 +1379,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
raise
ValueError
(
"For exponential ramp-up, the start RPS cannot be 0."
)
label
=
args
.
label
model_id
=
args
.
model
model_name
=
args
.
served_model_name
tokenizer_id
=
args
.
tokenizer
if
args
.
tokenizer
is
not
None
else
args
.
model
tokenizer_mode
=
args
.
tokenizer_mode
if
args
.
base_url
is
not
None
:
api_url
=
f
"
{
args
.
base_url
}{
args
.
endpoint
}
"
...
...
@@ -1356,6 +1399,18 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
else
:
raise
ValueError
(
"Invalid header format. Please use KEY=VALUE format."
)
# Fetch model from server if not specified
if
args
.
model
is
None
:
print
(
"Model not specified, fetching first model from server..."
)
model_id
=
await
get_first_model_from_server
(
base_url
,
headers
)
print
(
f
"Using model:
{
model_id
}
"
)
else
:
model_id
=
args
.
model
model_name
=
args
.
served_model_name
tokenizer_id
=
args
.
tokenizer
if
args
.
tokenizer
is
not
None
else
model_id
tokenizer_mode
=
args
.
tokenizer_mode
tokenizer
=
get_tokenizer
(
tokenizer_id
,
tokenizer_mode
=
tokenizer_mode
,
...
...
@@ -1368,6 +1423,20 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
"'--dataset-path' if required."
)
# Map general --input-len and --output-len to all dataset-specific arguments
if
args
.
input_len
is
not
None
:
args
.
random_input_len
=
args
.
input_len
args
.
sonnet_input_len
=
args
.
input_len
if
args
.
output_len
is
not
None
:
args
.
random_output_len
=
args
.
output_len
args
.
sonnet_output_len
=
args
.
output_len
args
.
sharegpt_output_len
=
args
.
output_len
args
.
custom_output_len
=
args
.
output_len
args
.
hf_output_len
=
args
.
output_len
args
.
spec_bench_output_len
=
args
.
output_len
args
.
prefix_repetition_output_len
=
args
.
output_len
# when using random datasets, default to ignoring EOS
# so generation runs to the requested length
if
(
...
...
vllm/benchmarks/throughput.py
View file @
a810671a
...
...
@@ -346,7 +346,10 @@ def get_requests(args, tokenizer):
"output_len"
:
args
.
output_len
,
}
if
args
.
dataset_path
is
None
or
args
.
dataset_name
==
"random"
:
if
args
.
dataset_name
==
"random"
or
(
args
.
dataset_path
is
None
and
args
.
dataset_name
not
in
{
"prefix_repetition"
,
"random-mm"
,
"random-rerank"
}
):
sample_kwargs
[
"range_ratio"
]
=
args
.
random_range_ratio
sample_kwargs
[
"prefix_len"
]
=
args
.
prefix_len
dataset_cls
=
RandomDataset
...
...
vllm/compilation/backends.py
View file @
a810671a
...
...
@@ -520,6 +520,7 @@ class VllmBackend:
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
is_encoder
:
bool
=
False
,
):
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
...
...
@@ -530,7 +531,7 @@ class VllmBackend:
self
.
prefix
=
prefix
or
model_tag
# Mark compilation for encoder.
self
.
is_encoder
=
model_is_encoder
self
.
is_encoder
=
is_encoder
or
model_is_encoder
# Passes to run on the graph post-grad.
self
.
pass_manager
=
resolve_obj_by_qualname
(
...
...
@@ -797,7 +798,7 @@ class VllmBackend:
or
not
self
.
compilation_config
.
cudagraph_copy_inputs
):
return
VllmSerializableFunction
(
graph
,
example_inputs
,
self
.
prefix
,
self
.
split_gm
graph
,
example_inputs
,
self
.
prefix
,
self
.
split_gm
,
self
.
is_encoder
)
# index of tensors that have symbolic shapes (batch size)
...
...
@@ -835,5 +836,5 @@ class VllmBackend:
return
self
.
split_gm
(
*
list_args
)
return
VllmSerializableFunction
(
graph
,
example_inputs
,
self
.
prefix
,
copy_and_call
graph
,
example_inputs
,
self
.
prefix
,
copy_and_call
,
self
.
is_encoder
)
vllm/compilation/caching.py
View file @
a810671a
...
...
@@ -37,12 +37,15 @@ class VllmSerializableFunction(SerializableCallable):
serializing the Dynamo fx graph plus example inputs.
"""
def
__init__
(
self
,
graph_module
,
example_inputs
,
prefix
,
optimized_call
):
def
__init__
(
self
,
graph_module
,
example_inputs
,
prefix
,
optimized_call
,
is_encoder
=
False
):
assert
isinstance
(
graph_module
,
torch
.
fx
.
GraphModule
)
self
.
graph_module
=
graph_module
self
.
example_inputs
=
example_inputs
self
.
prefix
=
prefix
self
.
optimized_call
=
optimized_call
self
.
is_encoder
=
is_encoder
self
.
shape_env
=
None
sym_input
=
next
(
(
i
for
i
in
self
.
example_inputs
if
isinstance
(
i
,
torch
.
SymInt
)),
None
...
...
@@ -104,8 +107,12 @@ class VllmSerializableFunction(SerializableCallable):
state
=
pickle
.
loads
(
data
)
fake_mode
=
FakeTensorMode
(
shape_env
=
ShapeEnv
())
state
[
"graph_module"
]
=
GraphPickler
.
loads
(
state
[
"graph_module"
],
fake_mode
)
state
[
"graph_module"
].
recompile
()
state
[
"example_inputs"
]
=
GraphPickler
.
loads
(
state
[
"example_inputs"
],
fake_mode
)
vllm_backend
=
VllmBackend
(
get_current_vllm_config
(),
state
[
"prefix"
])
is_encoder
=
state
.
get
(
"is_encoder"
,
False
)
vllm_backend
=
VllmBackend
(
get_current_vllm_config
(),
state
[
"prefix"
],
is_encoder
)
def
optimized_call
(
*
example_inputs
):
"""
...
...
vllm/compilation/decorators.py
View file @
a810671a
...
...
@@ -435,7 +435,10 @@ def _support_torch_compile(
return
self
.
aot_compiled_fn
(
self
,
*
args
,
**
kwargs
)
if
self
.
compiled
:
assert
not
envs
.
VLLM_USE_AOT_COMPILE
assert
(
not
envs
.
VLLM_USE_AOT_COMPILE
or
self
.
vllm_config
.
compilation_config
.
backend
==
"eager"
)
return
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
# This is the path for the first compilation.
...
...
@@ -508,7 +511,11 @@ def _support_torch_compile(
_torch27_patch_tensor_subclasses
(),
torch
.
_inductor
.
config
.
patch
(
**
inductor_config_patches
),
):
if
envs
.
VLLM_USE_AOT_COMPILE
:
use_aot_compile
=
envs
.
VLLM_USE_AOT_COMPILE
if
self
.
vllm_config
.
compilation_config
.
backend
==
"eager"
:
logger
.
warning
(
"Detected eager backend, disabling AOT compile."
)
use_aot_compile
=
False
if
use_aot_compile
:
self
.
aot_compiled_fn
=
self
.
aot_compile
(
*
args
,
**
kwargs
)
output
=
self
.
aot_compiled_fn
(
self
,
*
args
,
**
kwargs
)
assert
aot_compilation_path
is
not
None
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment