Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3f8d5dd
Commit
a3f8d5dd
authored
Dec 17, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori
parents
8d75f22e
f34eca5f
Changes
499
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
756 additions
and
487 deletions
+756
-487
vllm/attention/ops/triton_unified_attention.py
vllm/attention/ops/triton_unified_attention.py
+30
-39
vllm/attention/ops/vit_attn_wrappers.py
vllm/attention/ops/vit_attn_wrappers.py
+11
-8
vllm/attention/selector.py
vllm/attention/selector.py
+39
-48
vllm/benchmarks/serve.py
vllm/benchmarks/serve.py
+4
-2
vllm/benchmarks/startup.py
vllm/benchmarks/startup.py
+326
-0
vllm/compilation/backends.py
vllm/compilation/backends.py
+29
-2
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+22
-11
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+52
-48
vllm/compilation/matcher_utils.py
vllm/compilation/matcher_utils.py
+13
-7
vllm/compilation/piecewise_backend.py
vllm/compilation/piecewise_backend.py
+1
-6
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+10
-8
vllm/config/compilation.py
vllm/config/compilation.py
+12
-87
vllm/config/kv_transfer.py
vllm/config/kv_transfer.py
+5
-0
vllm/config/model.py
vllm/config/model.py
+45
-138
vllm/config/parallel.py
vllm/config/parallel.py
+0
-5
vllm/config/pooler.py
vllm/config/pooler.py
+4
-2
vllm/config/scheduler.py
vllm/config/scheduler.py
+3
-1
vllm/config/utils.py
vllm/config/utils.py
+16
-2
vllm/config/vllm.py
vllm/config/vllm.py
+90
-73
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+44
-0
No files found.
vllm/attention/ops/triton_unified_attention.py
View file @
a3f8d5dd
...
...
@@ -355,7 +355,7 @@ def kernel_unified_attention_2d(
@
triton
.
jit
def
kernel_unified_attention_3d
(
segm_output_ptr
,
# [num_tokens, num_query_heads, num_segments, head_size]
# [num_tokens, num_query_heads, num_segments, head_size
_padded
]
segm_max_ptr
,
# [num_tokens, num_query_heads, num_segments]
segm_expsum_ptr
,
# [num_tokens, num_query_heads, num_segments]
query_ptr
,
# [num_tokens, num_query_heads, head_size]
...
...
@@ -749,6 +749,11 @@ def unified_attention(
q_descale
,
k_descale
,
v_descale
,
seq_threshold_3D
=
None
,
num_par_softmax_segments
=
None
,
softmax_segm_output
=
None
,
softmax_segm_max
=
None
,
softmax_segm_expsum
=
None
,
alibi_slopes
=
None
,
output_scale
=
None
,
qq_bias
=
None
,
...
...
@@ -793,8 +798,19 @@ def unified_attention(
TILE_SIZE_PREFILL
=
32
TILE_SIZE_DECODE
=
16
if
q
.
element_size
()
>=
2
else
32
# if batch contains a prefill
if
max_seqlen_q
>
1
or
total_num_q_blocks
*
num_kv_heads
>
128
:
# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
# 2. The batch includes at least one prefill request, or
# 3. The number of sequences exceeds the configured threshold
if
(
seq_threshold_3D
is
None
or
num_par_softmax_segments
is
None
or
softmax_segm_output
is
None
or
softmax_segm_max
is
None
or
softmax_segm_expsum
is
None
or
max_seqlen_q
>
1
or
num_seqs
>
seq_threshold_3D
):
kernel_unified_attention_2d
[
(
total_num_q_blocks
,
...
...
@@ -847,37 +863,12 @@ def unified_attention(
USE_FP8
=
output_scale
is
not
None
,
)
else
:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
# value that showed good performance in tests
NUM_SEGMENTS
=
16
segm_output
=
torch
.
empty
(
q
.
shape
[
0
],
num_query_heads
,
NUM_SEGMENTS
,
triton
.
next_power_of_2
(
head_size
),
dtype
=
torch
.
float32
,
device
=
q
.
device
,
)
segm_max
=
torch
.
empty
(
q
.
shape
[
0
],
num_query_heads
,
NUM_SEGMENTS
,
dtype
=
torch
.
float32
,
device
=
q
.
device
,
)
segm_expsum
=
torch
.
empty
(
q
.
shape
[
0
],
num_query_heads
,
NUM_SEGMENTS
,
dtype
=
torch
.
float32
,
device
=
q
.
device
,
)
kernel_unified_attention_3d
[(
total_num_q_blocks
,
num_kv_heads
,
NUM_SEGMENTS
)](
segm_output_ptr
=
segm_output
,
segm_max_ptr
=
segm_max
,
segm_expsum_ptr
=
segm_expsum
,
kernel_unified_attention_3d
[
(
total_num_q_blocks
,
num_kv_heads
,
num_par_softmax_segments
)
](
segm_output_ptr
=
softmax_segm_output
,
segm_max_ptr
=
softmax_segm_max
,
segm_expsum_ptr
=
softmax_segm_expsum
,
query_ptr
=
q
,
key_cache_ptr
=
k
,
value_cache_ptr
=
v
,
...
...
@@ -917,13 +908,13 @@ def unified_attention(
BLOCK_Q
=
BLOCK_Q
,
num_seqs
=
num_seqs
,
BLOCK_M
=
BLOCK_M
,
NUM_SEGMENTS_PER_SEQ
=
NUM_SEGMENTS
,
NUM_SEGMENTS_PER_SEQ
=
num_par_softmax_segments
,
)
reduce_segments
[(
q
.
shape
[
0
],
num_query_heads
)](
output_ptr
=
out
,
segm_output_ptr
=
segm_output
,
segm_max_ptr
=
segm_max
,
segm_expsum_ptr
=
segm_expsum
,
segm_output_ptr
=
softmax_
segm_output
,
segm_max_ptr
=
softmax_
segm_max
,
segm_expsum_ptr
=
softmax_
segm_expsum
,
seq_lens_ptr
=
seqused_k
,
num_seqs
=
num_seqs
,
num_query_heads
=
num_query_heads
,
...
...
@@ -936,6 +927,6 @@ def unified_attention(
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
query_start_len_ptr
=
cu_seqlens_q
,
BLOCK_Q
=
BLOCK_Q
,
NUM_SEGMENTS_PER_SEQ
=
NUM_SEGMENTS
,
NUM_SEGMENTS_PER_SEQ
=
num_par_softmax_segments
,
USE_FP8
=
output_scale
is
not
None
,
)
vllm/attention/ops/vit_attn_wrappers.py
View file @
a3f8d5dd
...
...
@@ -16,6 +16,7 @@ import einops
import
torch
import
torch.nn.functional
as
F
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
...
...
@@ -44,9 +45,7 @@ def flash_attn_maxseqlen_wrapper(
dropout_p
=
0.0
,
causal
=
False
,
)
context_layer
=
einops
.
rearrange
(
output
,
"(b s) h d -> s b (h d)"
,
b
=
batch_size
).
contiguous
()
context_layer
=
einops
.
rearrange
(
output
,
"(b s) h d -> b s h d"
,
b
=
batch_size
)
return
context_layer
...
...
@@ -59,8 +58,7 @@ def flash_attn_maxseqlen_wrapper_fake(
batch_size
:
int
,
is_rocm_aiter
:
bool
,
)
->
torch
.
Tensor
:
b
,
s
,
h
,
d
=
q
.
shape
return
torch
.
empty
((
s
,
b
,
h
*
d
),
dtype
=
q
.
dtype
,
device
=
q
.
device
)
return
torch
.
empty_like
(
q
)
direct_register_custom_op
(
...
...
@@ -92,6 +90,13 @@ def torch_sdpa_wrapper(
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
if
current_platform
.
is_rocm
():
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
v
=
v
.
contiguous
()
outputs
=
[]
lens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
...
...
@@ -106,7 +111,6 @@ def torch_sdpa_wrapper(
output_i
=
einops
.
rearrange
(
output_i
,
"b h s d -> b s h d "
)
outputs
.
append
(
output_i
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
context_layer
=
einops
.
rearrange
(
context_layer
,
"b s h d -> s b (h d)"
).
contiguous
()
return
context_layer
...
...
@@ -116,8 +120,7 @@ def torch_sdpa_wrapper_fake(
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
b
,
s
,
h
,
d
=
q
.
shape
return
torch
.
empty
((
s
,
b
,
h
*
d
),
dtype
=
q
.
dtype
,
device
=
q
.
device
)
return
torch
.
empty_like
(
q
)
direct_register_custom_op
(
...
...
vllm/attention/selector.py
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
inspect
from
functools
import
cache
from
typing
import
cast
,
get_args
from
typing
import
NamedTuple
,
cast
,
get_args
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionType
from
vllm.attention.backends.registry
import
(
MAMBA_TYPE_TO_BACKEND_MAP
,
MambaAttentionBackendEnum
,
...
...
@@ -19,6 +18,31 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger
=
init_logger
(
__name__
)
class AttentionSelectorConfig(NamedTuple):
    """Immutable bundle of the settings used to pick an attention backend.

    Being a ``NamedTuple`` the value is hashable, so it can be handed as a
    single argument to the ``@cache``-decorated backend lookup in this module.
    """

    # Required settings (no defaults).
    head_size: int
    dtype: torch.dtype
    kv_cache_dtype: CacheDType | None
    block_size: int | None
    # Optional feature switches, all off unless the caller says otherwise.
    use_mla: bool = False
    has_sink: bool = False
    use_sparse: bool = False
    use_mm_prefix: bool = False
    # Attention flavour; falls back to the decoder type.
    attn_type: str = AttentionType.DECODER

    def __repr__(self):
        """Render every field as ``name=value`` in declaration order."""
        # _fields and tuple iteration share the declaration order, so pairing
        # them yields the same text as spelling out each field by hand.
        rendered = ", ".join(
            f"{name}={value}" for name, value in zip(self._fields, self)
        )
        return f"AttentionSelectorConfig({rendered})"
def
get_attn_backend
(
head_size
:
int
,
dtype
:
torch
.
dtype
,
...
...
@@ -44,8 +68,7 @@ def get_attn_backend(
vllm_config
=
get_current_vllm_config
()
backend_enum
=
vllm_config
.
attention_config
.
backend
return
_cached_get_attn_backend
(
backend
=
backend_enum
,
attn_selector_config
=
AttentionSelectorConfig
(
head_size
=
head_size
,
dtype
=
dtype
,
kv_cache_dtype
=
cast
(
CacheDType
|
None
,
kv_cache_dtype
),
...
...
@@ -54,58 +77,26 @@ def get_attn_backend(
has_sink
=
has_sink
,
use_sparse
=
use_sparse
,
use_mm_prefix
=
use_mm_prefix
,
attn_type
=
attn_type
,
attn_type
=
attn_type
or
AttentionType
.
DECODER
,
)
return
_cached_get_attn_backend
(
backend
=
backend_enum
,
attn_selector_config
=
attn_selector_config
,
)
@
cache
def
_cached_get_attn_backend
(
backend
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
CacheDType
|
None
,
block_size
:
int
|
None
,
use_mla
:
bool
=
False
,
has_sink
:
bool
=
False
,
use_sparse
:
bool
=
False
,
use_mm_prefix
:
bool
=
False
,
attn_type
:
str
|
None
=
None
,
attn_selector_config
:
AttentionSelectorConfig
,
)
->
type
[
AttentionBackend
]:
from
vllm.platforms
import
current_platform
sig
=
inspect
.
signature
(
current_platform
.
get_attn_backend_cls
)
if
"use_v1"
in
sig
.
parameters
:
logger
.
warning_once
(
"use_v1 parameter for get_attn_backend_cls is deprecated and will "
"be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
"remove it from your plugin code."
)
attention_cls
=
current_platform
.
get_attn_backend_cls
(
backend
,
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
True
,
# use_v1
use_mla
,
has_sink
,
use_sparse
,
use_mm_prefix
,
attn_type
,
)
else
:
attention_cls
=
current_platform
.
get_attn_backend_cls
(
backend
,
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_mla
,
has_sink
,
use_sparse
,
use_mm_prefix
,
attn_type
,
)
attention_cls
=
current_platform
.
get_attn_backend_cls
(
backend
,
attn_selector_config
=
attn_selector_config
,
)
if
not
attention_cls
:
raise
ValueError
(
f
"Invalid attention backend for
{
current_platform
.
device_name
}
"
...
...
vllm/benchmarks/serve.py
View file @
a3f8d5dd
...
...
@@ -235,7 +235,9 @@ async def get_request(
def
calculate_metrics_for_embeddings
(
outputs
:
list
[
RequestFuncOutput
],
dur_s
:
float
,
selected_percentiles
:
list
[
float
]
outputs
:
list
[
RequestFuncOutput
],
dur_s
:
float
,
selected_percentiles
:
list
[
float
],
)
->
EmbedBenchmarkMetrics
:
"""Calculate the metrics for the embedding requests.
...
...
@@ -788,7 +790,7 @@ async def benchmark(
)
print
(
"{:<40} {:<10.2f}"
.
format
(
"Total
T
oken throughput (tok/s):"
,
metrics
.
total_token_throughput
"Total
t
oken throughput (tok/s):"
,
metrics
.
total_token_throughput
)
)
...
...
vllm/benchmarks/startup.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark the cold and warm startup time of vLLM models.
This script measures total startup time (including model loading, compilation,
and cache operations) for both cold and warm scenarios:
- Cold startup: Fresh start with no caches (temporary cache directories)
- Warm startup: Using cached compilation and model info
"""
import
argparse
import
dataclasses
import
json
import
multiprocessing
import
os
import
shutil
import
tempfile
import
time
from
contextlib
import
contextmanager
from
typing
import
Any
import
numpy
as
np
from
tqdm
import
tqdm
from
vllm.benchmarks.lib.utils
import
(
convert_to_pytorch_benchmark_format
,
write_to_json
,
)
from
vllm.engine.arg_utils
import
EngineArgs
@
contextmanager
def
cold_startup
():
"""
Context manager to measure cold startup time:
1. Uses a temporary directory for vLLM cache to avoid any pollution
between cold startup iterations.
2. Uses inductor's fresh_cache to clear torch.compile caches.
"""
from
torch._inductor.utils
import
fresh_cache
# Use temporary directory for caching to avoid any pollution between cold startups
original_cache_root
=
os
.
environ
.
get
(
"VLLM_CACHE_ROOT"
)
temp_cache_dir
=
tempfile
.
mkdtemp
(
prefix
=
"vllm_startup_bench_cold_"
)
try
:
os
.
environ
[
"VLLM_CACHE_ROOT"
]
=
temp_cache_dir
with
fresh_cache
():
yield
finally
:
# Clean up temporary cache directory
shutil
.
rmtree
(
temp_cache_dir
,
ignore_errors
=
True
)
if
original_cache_root
:
os
.
environ
[
"VLLM_CACHE_ROOT"
]
=
original_cache_root
else
:
os
.
environ
.
pop
(
"VLLM_CACHE_ROOT"
,
None
)
def run_startup_in_subprocess(engine_args_dict, result_queue):
    """
    Build an ``LLM`` from ``engine_args_dict`` and report how long it took.

    Intended as a subprocess target so every measurement is isolated.
    On success one dict with ``total_startup_time`` and ``compilation_time``
    is put on ``result_queue``. On any exception two items are put instead:
    ``None`` followed by the exception text.
    """
    try:
        # Import inside the subprocess to avoid issues with forking
        from vllm import LLM
        from vllm.engine.arg_utils import EngineArgs

        engine_args = EngineArgs(**engine_args_dict)

        # Time only the LLM construction itself.
        began = time.perf_counter()
        llm = LLM(**dataclasses.asdict(engine_args))
        elapsed = time.perf_counter() - began

        # Compilation time is optional: fall back to 0.0 whenever the engine
        # exposes no vllm_config or no (non-None) compilation_config.
        compilation_time = 0.0
        engine_config = getattr(llm.llm_engine, "vllm_config", None)
        compile_config = getattr(engine_config, "compilation_config", None)
        if compile_config is not None:
            compilation_time = compile_config.compilation_time

        payload = {
            "total_startup_time": elapsed,
            "compilation_time": compilation_time,
        }
        result_queue.put(payload)
    except Exception as e:
        # Failure protocol understood by the parent: a None marker first,
        # then the error message.
        result_queue.put(None)
        result_queue.put(str(e))
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write one PyTorch-benchmark-format JSON file per measured metric.

    For each of the four metrics (cold/warm x startup/compilation) the average
    from ``results`` becomes the metric, while the raw samples and percentiles
    are attached as extra info. Files are named
    ``<output_json stem>.<metric>.pytorch.json``; a file is only written when
    the converter returns a non-empty record list.
    """
    base_name = os.path.splitext(args.output_json)[0]

    # Same order as the result keys are produced: cold first, then warm.
    metric_tags = (
        "cold_startup",
        "cold_compilation",
        "warm_startup",
        "warm_compilation",
    )
    for tag in metric_tags:
        average_key = f"avg_{tag}_time"
        samples_key = f"{tag}_times"
        percentiles_key = f"{tag}_percentiles"

        records = convert_to_pytorch_benchmark_format(
            args=args,
            metrics={average_key: results[average_key]},
            extra_info={
                samples_key: results[samples_key],
                percentiles_key: results[percentiles_key],
            },
        )
        if records:
            write_to_json(f"{base_name}.{tag}.pytorch.json", records)
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the startup-benchmark options, then the engine options.

    Returns the parser handed back by ``EngineArgs.add_cli_args``.
    """
    # (flag, default, help) for the three integer iteration counts.
    iteration_options = (
        ("--num-iters-cold", 5, "Number of cold startup iterations."),
        (
            "--num-iters-warmup",
            3,
            "Number of warmup iterations before benchmarking warm startups.",
        ),
        ("--num-iters-warm", 5, "Number of warm startup iterations."),
    )
    for flag, default_count, help_text in iteration_options:
        parser.add_argument(flag, type=int, default=default_count, help=help_text)

    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the startup time results in JSON format.",
    )

    return EngineArgs.add_cli_args(parser)
def main(args: argparse.Namespace):
    """Measure cold and warm startup, print a report, optionally save JSON.

    Each measurement builds the LLM in a freshly spawned subprocess. Cold
    iterations run inside ``cold_startup()``; warm iterations run after a
    number of unmeasured warmup launches. When ``args.output_json`` is set the
    results are dumped there and also exported in PyTorch benchmark format.
    """
    # Set multiprocessing start method to 'spawn' for clean process isolation
    # This ensures each subprocess starts fresh without inheriting state
    multiprocessing.set_start_method("spawn", force=True)

    engine_args = EngineArgs.from_cli_args(args)

    def create_llm_and_measure_startup():
        """
        Launch one isolated startup measurement and return its metrics dict.
        Raises RuntimeError when the subprocess reports failure or nothing.
        """
        result_queue = multiprocessing.Queue()
        worker = multiprocessing.Process(
            target=run_startup_in_subprocess,
            # Engine args are passed as a plain dict for pickling.
            args=(dataclasses.asdict(engine_args), result_queue),
        )
        worker.start()
        worker.join()

        if result_queue.empty():
            raise RuntimeError("Subprocess did not return a result")

        result = result_queue.get()
        if result is not None:
            return result

        # A leading None marks failure; the error text follows if present.
        if result_queue.empty():
            raise RuntimeError("Subprocess failed with unknown error")
        error_msg = result_queue.get()
        raise RuntimeError(f"Subprocess failed: {error_msg}")

    os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
    print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n")

    print("Measuring cold startup time...\n")
    cold_startup_times = []
    cold_compilation_times = []
    for _ in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"):
        with cold_startup():
            metrics = create_llm_and_measure_startup()
            cold_startup_times.append(metrics["total_startup_time"])
            cold_compilation_times.append(metrics["compilation_time"])

    # Warmup for warm startup
    print("\nWarming up for warm startup measurement...\n")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        create_llm_and_measure_startup()

    print("\nMeasuring warm startup time...\n")
    warm_startup_times = []
    warm_compilation_times = []
    for _ in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"):
        metrics = create_llm_and_measure_startup()
        warm_startup_times.append(metrics["total_startup_time"])
        warm_compilation_times.append(metrics["compilation_time"])

    # Calculate statistics
    percentages = [10, 25, 50, 75, 90, 99]

    def summarize(samples):
        """Return (mean, percentiles at `percentages`) for a list of timings."""
        values = np.array(samples)
        return np.mean(values), np.percentile(values, percentages)

    avg_cold_startup, cold_startup_percentiles = summarize(cold_startup_times)
    avg_cold_compilation, cold_compilation_percentiles = summarize(
        cold_compilation_times
    )
    avg_warm_startup, warm_startup_percentiles = summarize(warm_startup_times)
    avg_warm_compilation, warm_compilation_percentiles = summarize(
        warm_compilation_times
    )

    def report(heading, avg_startup, avg_compilation, startup_pcts, compile_pcts):
        """Print the averages and both percentile tables for one phase."""
        print(heading)
        print(f"Avg total startup time: {avg_startup:.2f} seconds")
        print(f"Avg compilation time:   {avg_compilation:.2f} seconds")
        sections = (
            ("Startup time percentiles:", startup_pcts),
            ("Compilation time percentiles:", compile_pcts),
        )
        for title, series in sections:
            print(title)
            for percentage, percentile in zip(percentages, series):
                print(f"  {percentage}%: {percentile:.2f} seconds")

    banner = "=" * 60
    print("\n" + banner)
    print("STARTUP TIME BENCHMARK RESULTS")
    print(banner)

    # Cold startup statistics
    report(
        "\nCOLD STARTUP:",
        avg_cold_startup,
        avg_cold_compilation,
        cold_startup_percentiles,
        cold_compilation_percentiles,
    )

    # Warm startup statistics
    report(
        "\nWARM STARTUP:",
        avg_warm_startup,
        avg_warm_compilation,
        warm_startup_percentiles,
        warm_compilation_percentiles,
    )

    print(banner)

    # Output JSON results if specified
    if args.output_json:

        def percentile_table(series):
            """Map each requested percentage to its value as plain floats."""
            return dict(zip(percentages, series.tolist()))

        results = {
            "avg_cold_startup_time": float(avg_cold_startup),
            "avg_cold_compilation_time": float(avg_cold_compilation),
            "cold_startup_times": cold_startup_times,
            "cold_compilation_times": cold_compilation_times,
            "cold_startup_percentiles": percentile_table(cold_startup_percentiles),
            "cold_compilation_percentiles": percentile_table(
                cold_compilation_percentiles
            ),
            "avg_warm_startup_time": float(avg_warm_startup),
            "avg_warm_compilation_time": float(avg_warm_compilation),
            "warm_startup_times": warm_startup_times,
            "warm_compilation_times": warm_compilation_times,
            "warm_startup_percentiles": percentile_table(warm_startup_percentiles),
            "warm_compilation_percentiles": percentile_table(
                warm_compilation_percentiles
            ),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
vllm/compilation/backends.py
View file @
a3f8d5dd
...
...
@@ -141,7 +141,25 @@ class CompilerManager:
# we use ast.literal_eval to parse the data
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
self
.
cache
=
ast
.
literal_eval
(
f
.
read
())
cache
=
ast
.
literal_eval
(
f
.
read
())
def check_type(value, ty):
    """Raise TypeError unless ``value`` is an instance of ``ty``."""
    if isinstance(value, ty):
        return
    raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
def parse_key(key: Any) -> tuple[Range, int, str]:
    """Validate one loaded cache key and normalise its range component.

    ``key`` must unpack into (range, graph index, compiler name). A range
    given as a plain ``(start, end)`` tuple is converted to ``Range``; every
    component is type-checked, and a mismatch raises ``TypeError``.
    """
    shape_range, graph_index, compiler_name = key
    check_type(graph_index, int)
    check_type(compiler_name, str)

    if isinstance(shape_range, tuple):
        # Tuple form: rebuild the Range object from its two integer bounds.
        lower, upper = shape_range
        check_type(lower, int)
        check_type(upper, int)
        shape_range = Range(start=lower, end=upper)

    # Whatever form it arrived in, it must be a Range by now.
    check_type(shape_range, Range)
    return shape_range, graph_index, compiler_name
self
.
cache
=
{
parse_key
(
key
):
value
for
key
,
value
in
cache
.
items
()}
self
.
compiler
.
initialize_cache
(
cache_dir
=
cache_dir
,
disable_cache
=
disable_cache
,
prefix
=
prefix
...
...
@@ -445,21 +463,27 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag
:
str
=
"backbone"
model_is_encoder
:
bool
=
False
@
contextmanager
def
set_model_tag
(
tag
:
str
):
def
set_model_tag
(
tag
:
str
,
is_encoder
:
bool
=
False
):
"""Context manager to set the model tag."""
global
model_tag
global
model_is_encoder
assert
tag
!=
model_tag
,
(
f
"Model tag
{
tag
}
is the same as the current tag
{
model_tag
}
."
)
old_tag
=
model_tag
old_is_encoder
=
model_is_encoder
model_tag
=
tag
model_is_encoder
=
is_encoder
try
:
yield
finally
:
model_tag
=
old_tag
model_is_encoder
=
old_is_encoder
class
VllmBackend
:
...
...
@@ -505,6 +529,9 @@ class VllmBackend:
# them, e.g. backbone (default), eagle_head, etc.
self
.
prefix
=
prefix
or
model_tag
# Mark compilation for encoder.
self
.
is_encoder
=
model_is_encoder
# Passes to run on the graph post-grad.
self
.
pass_manager
=
resolve_obj_by_qualname
(
current_platform
.
get_pass_manager_cls
()
...
...
vllm/compilation/decorators.py
View file @
a3f8d5dd
...
...
@@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.utils.torch_utils
import
supports_dynamo
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
,
supports_dynamo
from
.monitor
import
start_monitoring_torch_compile
...
...
@@ -316,7 +316,13 @@ def _support_torch_compile(
def
_mark_dynamic_inputs
(
mod
,
type
,
*
args
,
**
kwargs
):
def
mark_dynamic
(
arg
,
dims
):
if
type
==
DynamicShapesType
.
UNBACKED
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dims
)
if
is_torch_equal_or_newer
(
"2.10.0.dev"
):
for
dim
in
dims
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dim
,
hint_override
=
arg
.
size
()[
dim
]
)
else
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dims
)
else
:
torch
.
_dynamo
.
mark_dynamic
(
arg
,
dims
)
...
...
@@ -350,7 +356,13 @@ def _support_torch_compile(
if
isinstance
(
arg
,
torch
.
Tensor
):
# In case dims is specified with negative indexing
dims
=
[
arg
.
ndim
+
dim
if
dim
<
0
else
dim
for
dim
in
dims
]
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dims
)
if
is_torch_equal_or_newer
(
"2.10.0.dev"
):
for
dim
in
dims
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dim
,
hint_override
=
arg
.
size
()[
dim
]
)
else
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dims
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
# torch.compiler.is_compiling() means we are inside the compilation
...
...
@@ -378,14 +390,6 @@ def _support_torch_compile(
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
# Validate that AOT compile is not used with unbacked dynamic
# shapes. aot_compile re-allocates backed symbols post dynamo!
if
ds_type
==
DynamicShapesType
.
UNBACKED
:
raise
ValueError
(
"AOT compilation is not compatible with UNBACKED dynamic shapes. "
"Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
"when VLLM_USE_AOT_COMPILE is enabled."
)
from
.caching
import
compilation_config_hash_factors
factors
:
list
[
str
]
=
compilation_config_hash_factors
(
self
.
vllm_config
)
...
...
@@ -488,6 +492,12 @@ def _support_torch_compile(
if
ds_type
==
DynamicShapesType
.
BACKED_SIZE_OBLIVIOUS
:
fx_config_patches
[
"backed_size_oblivious"
]
=
True
# Prepare inductor config patches
# assume_32bit_indexing is only available in torch 2.10.0.dev+
inductor_config_patches
=
{}
if
is_torch_equal_or_newer
(
"2.10.0.dev"
):
inductor_config_patches
[
"assume_32bit_indexing"
]
=
True
with
(
patch
.
object
(
InliningInstructionTranslator
,
"inline_call_"
,
patched_inline_call
...
...
@@ -496,6 +506,7 @@ def _support_torch_compile(
maybe_use_cudagraph_partition_wrapper
(
self
.
vllm_config
),
torch
.
fx
.
experimental
.
_config
.
patch
(
**
fx_config_patches
),
_torch27_patch_tensor_subclasses
(),
torch
.
_inductor
.
config
.
patch
(
**
inductor_config_patches
),
):
if
envs
.
VLLM_USE_AOT_COMPILE
:
self
.
aot_compiled_fn
=
self
.
aot_compile
(
*
args
,
**
kwargs
)
...
...
vllm/compilation/fusion.py
View file @
a3f8d5dd
...
...
@@ -23,17 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Quant
,
kStaticTensorScale
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
cutlass_block_fp8_supported
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
(
is_deep_gemm_e8m0_used
,
should_use_deepgemm_for_fp8_linear_for_nk
,
)
from
.inductor_pass
import
enable_fake_mode
from
.matcher_utils
import
MatcherFusedAddRMSNorm
,
MatcherQuantFP8
,
MatcherRMSNorm
from
.matcher_utils
import
(
MatcherFusedAddRMSNorm
,
MatcherQuantFP8
,
MatcherRMSNorm
,
)
from
.vllm_inductor_pass
import
VllmInductorPass
,
VllmPatternMatcherPass
logger
=
init_logger
(
__name__
)
...
...
@@ -118,21 +115,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
class
RMSNormQuantPattern
:
def
__init__
(
self
,
epsilon
:
float
,
key
:
FusedRMSQuantKey
):
def
__init__
(
self
,
epsilon
:
float
,
key
:
FusedRMSQuantKey
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
self
.
epsilon
=
epsilon
self
.
quant_dtype
=
key
.
quant
.
dtype
config
=
get_current_vllm_config
()
self
.
model_dtype
=
config
.
model_config
.
dtype
if
config
.
model_config
else
None
# groupwise FP8 linear uses col major scales if deepgemm and cutlass
using_deepgemm
=
should_use_deepgemm_for_fp8_linear_for_nk
(
self
.
model_dtype
,
config
.
model_config
.
hf_config
.
intermediate_size
,
config
.
model_config
.
hf_config
.
hidden_size
,
)
use_col_major_scales
=
using_deepgemm
or
cutlass_block_fp8_supported
()
use_e8m0
=
is_deep_gemm_e8m0_used
()
if
using_deepgemm
else
False
assert
key
in
FUSED_OPS
,
f
"unsupported fused rmsnorm+quant op for
{
key
}
"
self
.
FUSED_OP
=
FUSED_OPS
[
key
]
...
...
@@ -142,7 +136,7 @@ class RMSNormQuantPattern:
else
MatcherFusedAddRMSNorm
(
epsilon
)
)
self
.
quant_matcher
=
MatcherQuantFP8
(
key
.
quant
,
use
_col_major_scales
=
use
_col_major_scales
,
use
_e8m0
=
use
_e8m0
key
.
quant
,
has
_col_major_scales
=
has
_col_major_scales
,
is
_e8m0
=
is
_e8m0
)
...
...
@@ -260,6 +254,8 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
symmetric
=
True
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
...
...
@@ -267,7 +263,11 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant
=
QuantKey
(
dtype
=
quant_dtype
,
scale
=
scale
,
symmetric
=
symmetric
),
)
self
.
group_shape
=
group_shape
super
().
__init__
(
epsilon
,
key
)
self
.
has_col_major_scales
=
has_col_major_scales
self
.
is_e8m0
=
is_e8m0
super
().
__init__
(
epsilon
,
key
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
...
...
@@ -283,9 +283,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
result
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
scale
=
self
.
quant_matcher
.
make_scale
(
input
,
transposed
=
self
.
quant_matcher
.
use_col_major_scales
)
scale
=
self
.
quant_matcher
.
make_scale
(
input
,
self
.
has_col_major_scales
)
at
=
auto_functionalized
(
self
.
FUSED_OP
,
result
=
result
,
...
...
@@ -296,7 +294,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub
=
None
,
residual
=
residual
,
group_size
=
self
.
group_shape
[
1
],
is_scale_transposed
=
self
.
quant_matcher
.
use
_col_major_scales
,
is_scale_transposed
=
self
.
has
_col_major_scales
,
)
# result, residual, scale
...
...
@@ -318,6 +316,8 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
symmetric
=
True
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
...
...
@@ -325,7 +325,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant
=
QuantKey
(
dtype
=
quant_dtype
,
scale
=
scale
,
symmetric
=
symmetric
),
)
self
.
group_shape
=
group_shape
super
().
__init__
(
epsilon
,
key
)
super
().
__init__
(
epsilon
,
key
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
...
...
@@ -340,7 +342,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
result
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
scale
=
self
.
quant_matcher
.
make_scale
(
input
,
transposed
=
self
.
quant_matcher
.
use
_col_major_scales
input
,
transposed
=
self
.
quant_matcher
.
has
_col_major_scales
)
at
=
auto_functionalized
(
self
.
FUSED_OP
,
...
...
@@ -352,7 +354,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub
=
None
,
residual
=
None
,
group_size
=
self
.
group_shape
[
1
],
is_scale_transposed
=
self
.
quant_matcher
.
use
_col_major_scales
,
is_scale_transposed
=
self
.
quant_matcher
.
has
_col_major_scales
,
)
# result, scale
...
...
@@ -489,27 +491,6 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for
epsilon
in
[
1e-5
,
1e-6
]:
# Fuse fused_add_rms_norm + fp8 group quant
# Only register group quant patterns on CUDA where the C++ op exists
if
current_platform
.
is_cuda
():
FusedAddRMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
GroupShape
(
1
,
128
)
).
register
(
self
.
patterns
)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
GroupShape
(
1
,
128
)
).
register
(
self
.
patterns
)
FusedAddRMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
GroupShape
(
1
,
64
)
).
register
(
self
.
patterns
)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
GroupShape
(
1
,
64
)
).
register
(
self
.
patterns
)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
self
.
patterns
...
...
@@ -526,6 +507,29 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
self
.
patterns
)
# Only register group quant patterns on CUDA where the C++ op exists
if
current_platform
.
is_cuda
():
for
group_shape
in
[
GroupShape
(
1
,
128
),
GroupShape
(
1
,
64
)]:
for
has_col_major_scales
in
[
True
,
False
]:
for
is_e8m0
in
[
True
,
False
]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
group_shape
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
,
).
register
(
self
.
patterns
)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
group_shape
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
,
).
register
(
self
.
patterns
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
...
...
vllm/compilation/matcher_utils.py
View file @
a3f8d5dd
...
...
@@ -234,24 +234,30 @@ class MatcherQuantFP8(MatcherCustomOp):
self
,
quant_key
:
QuantKey
,
enabled
:
bool
|
None
=
None
,
use
_col_major_scales
:
bool
=
False
,
use
_e8m0
:
bool
=
False
,
has
_col_major_scales
:
bool
=
False
,
is
_e8m0
:
bool
=
False
,
):
if
enabled
is
None
:
enabled
=
QuantFP8
.
enabled
()
super
().
__init__
(
enabled
)
self
.
quant_key
=
quant_key
self
.
use_col_major_scales
=
use_col_major_scales
self
.
use_e8m0
=
use_e8m0
assert
quant_key
in
QUANT_OPS
,
f
"unsupported quantization scheme
{
quant_key
}
"
self
.
QUANT_OP
=
QUANT_OPS
[
quant_key
]
self
.
has_col_major_scales
=
has_col_major_scales
self
.
is_e8m0
=
is_e8m0
assert
quant_key
.
dtype
==
current_platform
.
fp8_dtype
(),
(
"Only QuantFP8 supported by"
)
assert
quant_key
.
scale2
is
None
self
.
quant_fp8
=
QuantFP8
(
quant_key
.
scale
.
static
,
quant_key
.
scale
.
group_shape
)
self
.
quant_fp8
=
QuantFP8
(
quant_key
.
scale
.
static
,
quant_key
.
scale
.
group_shape
,
column_major_scales
=
has_col_major_scales
,
use_ue8m0
=
is_e8m0
,
)
def
forward_custom
(
self
,
...
...
@@ -264,7 +270,7 @@ class MatcherQuantFP8(MatcherCustomOp):
if
self
.
quant_key
.
scale
.
group_shape
.
is_per_group
():
assert
scale
is
None
scale
=
self
.
make_scale
(
input
,
transposed
=
self
.
use
_col_major_scales
)
scale
=
self
.
make_scale
(
input
,
transposed
=
self
.
has
_col_major_scales
)
finfo
=
torch
.
finfo
(
self
.
quant_key
.
dtype
)
fp8_min
=
finfo
.
min
...
...
@@ -279,7 +285,7 @@ class MatcherQuantFP8(MatcherCustomOp):
eps
=
1e-10
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
scale_ue8m0
=
self
.
use
_e8m0
,
scale_ue8m0
=
self
.
is
_e8m0
,
)
return
result
,
scale
...
...
vllm/compilation/piecewise_backend.py
View file @
a3f8d5dd
...
...
@@ -53,12 +53,7 @@ class PiecewiseBackend:
self
.
is_last_graph
=
piecewise_compile_index
==
total_piecewise_compiles
-
1
self
.
is_full_graph
=
total_piecewise_compiles
==
1
# TODO: we need to generalize encoder compilation to other models
self
.
is_encoder_compilation
=
vllm_backend
.
prefix
in
[
"Qwen2_5_VisionPatchEmbed"
,
"Qwen2_5_VisionPatchMerger"
,
"Qwen2_5_VisionBlock"
,
]
self
.
is_encoder_compilation
=
vllm_backend
.
is_encoder
self
.
compile_ranges
=
self
.
compilation_config
.
get_compile_ranges
()
if
self
.
is_encoder_compilation
:
...
...
vllm/compilation/wrapper.py
View file @
a3f8d5dd
...
...
@@ -171,22 +171,24 @@ class TorchCompileWithNoGuardsWrapper:
compiled_ptr
=
self
.
check_invariants_and_forward
aot_context
=
nullcontext
()
if
envs
.
VLLM_USE_AOT_COMPILE
:
if
hasattr
(
torch
.
_dynamo
.
config
,
"enable_aot_compile"
):
torch
.
_dynamo
.
config
.
enable_aot_compile
=
True
aot_context
=
torch
.
_dynamo
.
config
.
patch
(
enable_aot_compile
=
True
)
else
:
msg
=
"torch._dynamo.config.enable_aot_compile is not "
msg
+=
"available. AOT compile is disabled and please "
msg
+=
"upgrade PyTorch version to use AOT compile."
logger
.
warning
(
msg
)
self
.
_compiled_callable
=
torch
.
compile
(
compiled_ptr
,
fullgraph
=
True
,
dynamic
=
False
,
backend
=
backend
,
options
=
options
,
)
with
aot_context
:
self
.
_compiled_callable
=
torch
.
compile
(
compiled_ptr
,
fullgraph
=
True
,
dynamic
=
False
,
backend
=
backend
,
options
=
options
,
)
if
envs
.
VLLM_USE_BYTECODE_HOOK
and
mode
!=
CompilationMode
.
STOCK_TORCH_COMPILE
:
torch
.
_dynamo
.
convert_frame
.
register_bytecode_hook
(
self
.
bytecode_hook
)
...
...
vllm/config/compilation.py
View file @
a3f8d5dd
...
...
@@ -8,7 +8,7 @@ from dataclasses import field
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Literal
from
pydantic
import
Field
,
TypeAdapter
,
field_validator
from
pydantic
import
ConfigDict
,
Field
,
TypeAdapter
,
field_validator
from
pydantic.dataclasses
import
dataclass
import
vllm.envs
as
envs
...
...
@@ -17,7 +17,6 @@ from vllm.config.utils import (
Range
,
config
,
get_hash_factors
,
handle_deprecated
,
hash_factors
,
)
from
vllm.logger
import
init_logger
...
...
@@ -97,7 +96,7 @@ class CUDAGraphMode(enum.Enum):
@
config
@
dataclass
@
dataclass
(
config
=
ConfigDict
(
extra
=
"forbid"
))
class
PassConfig
:
"""Configuration for custom Inductor passes.
...
...
@@ -127,27 +126,6 @@ class PassConfig:
fuse_allreduce_rms
:
bool
=
Field
(
default
=
None
)
"""Enable flashinfer allreduce fusion."""
# Deprecated flags
enable_fusion
:
bool
=
Field
(
default
=
None
)
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
"""
enable_attn_fusion
:
bool
=
Field
(
default
=
None
)
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_noop
:
bool
=
Field
(
default
=
None
)
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_sequence_parallelism
:
bool
=
Field
(
default
=
None
)
"""Deprecated in: v0.12.0. Use enable_sp instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_async_tp
:
bool
=
Field
(
default
=
None
)
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_fi_allreduce_fusion
:
bool
=
Field
(
default
=
None
)
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
fi_allreduce_fusion_max_size_mb
:
float
|
None
=
None
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
...
...
@@ -206,15 +184,7 @@ class PassConfig:
Any future fields that don't affect compilation should be excluded.
"""
ignored_fields
=
[
"enable_fusion"
,
"enable_attn_fusion"
,
"enable_noop"
,
"enable_sequence_parallelism"
,
"enable_async_tp"
,
"enable_fi_allreduce_fusion"
,
]
return
hash_factors
(
get_hash_factors
(
self
,
ignored_factors
=
ignored_fields
))
return
hash_factors
(
get_hash_factors
(
self
,
set
()))
@
field_validator
(
"fuse_norm_quant"
,
...
...
@@ -224,12 +194,6 @@ class PassConfig:
"enable_sp"
,
"fuse_gemm_comms"
,
"fuse_allreduce_rms"
,
"enable_fusion"
,
"enable_attn_fusion"
,
"enable_noop"
,
"enable_sequence_parallelism"
,
"enable_async_tp"
,
"enable_fi_allreduce_fusion"
,
mode
=
"wrap"
,
)
@
classmethod
...
...
@@ -242,49 +206,6 @@ class PassConfig:
def
__post_init__
(
self
)
->
None
:
# Handle deprecation and defaults
# Map old flags to new flags and issue warnings
handle_deprecated
(
self
,
"enable_fusion"
,
[
"fuse_norm_quant"
,
"fuse_act_quant"
],
"v0.13.0 or v1.0.0, whichever is sooner"
,
)
handle_deprecated
(
self
,
"enable_attn_fusion"
,
"fuse_attn_quant"
,
"v0.13.0 or v1.0.0, whichever is sooner"
,
)
handle_deprecated
(
self
,
"enable_sequence_parallelism"
,
"enable_sp"
,
"v0.13.0 or v1.0.0, whichever is sooner"
,
)
handle_deprecated
(
self
,
"enable_async_tp"
,
"fuse_gemm_comms"
,
"v0.13.0 or v1.0.0, whichever is sooner"
,
)
handle_deprecated
(
self
,
"enable_fi_allreduce_fusion"
,
"fuse_allreduce_rms"
,
"v0.13.0 or v1.0.0, whichever is sooner"
,
)
handle_deprecated
(
self
,
"enable_noop"
,
"eliminate_noops"
,
"v0.13.0 or v1.0.0, whichever is sooner"
,
)
if
not
self
.
eliminate_noops
:
if
self
.
fuse_norm_quant
or
self
.
fuse_act_quant
:
logger
.
warning_once
(
...
...
@@ -330,7 +251,7 @@ class DynamicShapesType(str, enum.Enum):
@
config
@
dataclass
@
dataclass
(
config
=
ConfigDict
(
extra
=
"forbid"
))
class
DynamicShapesConfig
:
"""Configuration to control/debug torch compile dynamic shapes."""
...
...
@@ -369,7 +290,7 @@ class DynamicShapesConfig:
@
config
@
dataclass
@
dataclass
(
config
=
ConfigDict
(
extra
=
"forbid"
))
class
CompilationConfig
:
"""Configuration for compilation.
...
...
@@ -1011,9 +932,13 @@ class CompilationConfig:
self
.
splitting_ops
=
list
(
self
.
_attention_ops
)
added_default_splitting_ops
=
True
elif
len
(
self
.
splitting_ops
)
==
0
:
logger
.
warning_once
(
"Using piecewise compilation with empty splitting_ops"
)
if
(
self
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
or
self
.
cudagraph_mode
==
CUDAGraphMode
.
FULL_AND_PIECEWISE
):
logger
.
warning_once
(
"Using piecewise compilation with empty splitting_ops"
)
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
:
logger
.
warning_once
(
"Piecewise compilation with empty splitting_ops do not"
...
...
vllm/config/kv_transfer.py
View file @
a3f8d5dd
...
...
@@ -64,6 +64,11 @@ class KVTransferConfig:
enable_permute_local_kv
:
bool
=
False
"""Experiment feature flag to enable HND to NHD KV Transfer"""
kv_load_failure_policy
:
Literal
[
"recompute"
,
"fail"
]
=
"recompute"
"""Policy for handling KV cache load failures.
'recompute': reschedule the request to recompute failed blocks (default)
'fail': immediately fail the request with an error finish reason"""
def
compute_hash
(
self
)
->
str
:
"""
WARNING: Whenever a new field is added to this config,
...
...
vllm/config/model.py
View file @
a3f8d5dd
...
...
@@ -8,7 +8,7 @@ from functools import cached_property
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
cast
,
get_args
import
torch
from
pydantic
import
ConfigDict
,
SkipValidation
,
field_validator
,
model_validator
from
pydantic
import
ConfigDict
,
Field
,
field_validator
,
model_validator
from
pydantic.dataclasses
import
dataclass
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
from
transformers.configuration_utils
import
ALLOWED_LAYER_TYPES
...
...
@@ -73,17 +73,6 @@ logger = init_logger(__name__)
RunnerOption
=
Literal
[
"auto"
,
RunnerType
]
ConvertType
=
Literal
[
"none"
,
"embed"
,
"classify"
,
"reward"
]
ConvertOption
=
Literal
[
"auto"
,
ConvertType
]
TaskOption
=
Literal
[
"auto"
,
"generate"
,
"embedding"
,
"embed"
,
"classify"
,
"score"
,
"reward"
,
"transcription"
,
"draft"
,
]
TokenizerMode
=
Literal
[
"auto"
,
"hf"
,
"slow"
,
"mistral"
,
"deepseek_v32"
]
ModelDType
=
Literal
[
"auto"
,
"half"
,
"float16"
,
"bfloat16"
,
"float"
,
"float32"
]
LogprobsMode
=
Literal
[
...
...
@@ -93,12 +82,6 @@ HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig]
ModelImpl
=
Literal
[
"auto"
,
"vllm"
,
"transformers"
,
"terratorch"
]
LayerBlockType
=
Literal
[
"attention"
,
"linear_attention"
,
"mamba"
]
_RUNNER_TASKS
:
dict
[
RunnerType
,
list
[
TaskOption
]]
=
{
"generate"
:
[
"generate"
,
"transcription"
],
"pooling"
:
[
"embedding"
,
"embed"
,
"classify"
,
"score"
,
"reward"
],
"draft"
:
[
"draft"
],
}
_RUNNER_CONVERTS
:
dict
[
RunnerType
,
list
[
ConvertType
]]
=
{
"generate"
:
[],
"pooling"
:
[
"embed"
,
"classify"
,
"reward"
],
...
...
@@ -126,13 +109,7 @@ class ModelConfig:
"""Convert the model using adapters defined in
[vllm.model_executor.models.adapters][]. The most common use case is to
adapt a text generation model to be used for pooling tasks."""
task
:
TaskOption
|
None
=
None
"""[DEPRECATED] The task to use the model for. If the model supports more
than one model runner, this is used to select which model runner to run.
Note that the model may support other tasks using the same model runner.
"""
tokenizer
:
SkipValidation
[
str
]
=
None
# type: ignore
tokenizer
:
str
=
Field
(
default
=
None
)
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
tokenizer_mode
:
TokenizerMode
|
str
=
"auto"
...
...
@@ -187,7 +164,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
max_model_len
:
SkipValidation
[
int
]
=
None
# type: ignore
max_model_len
:
int
=
Field
(
default
=
None
,
gt
=
0
)
"""Model context length (prompt and output). If unspecified, will be
automatically derived from the model config.
...
...
@@ -198,7 +175,7 @@ class ModelConfig:
- 25.6k -> 25,600"""
spec_target_max_model_len
:
int
|
None
=
None
"""Specify the maximum length for spec decoding draft models."""
quantization
:
SkipValidation
[
QuantizationMethods
|
None
]
=
None
quantization
:
QuantizationMethods
|
str
|
None
=
None
"""Method used to quantize the weights. If `None`, we first check the
`quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to
...
...
@@ -338,7 +315,6 @@ class ModelConfig:
ignored_factors
=
{
"runner"
,
"convert"
,
"task"
,
"tokenizer"
,
"tokenizer_mode"
,
"seed"
,
...
...
@@ -513,97 +489,6 @@ class ModelConfig:
is_generative_model
=
registry
.
is_text_generation_model
(
architectures
,
self
)
is_pooling_model
=
registry
.
is_pooling_model
(
architectures
,
self
)
def
_task_to_convert
(
task
:
TaskOption
)
->
ConvertType
:
if
task
==
"embedding"
or
task
==
"embed"
:
return
"embed"
if
task
==
"classify"
:
return
"classify"
if
task
==
"reward"
:
logger
.
warning
(
"Pooling models now default support all pooling; "
"you can use it without any settings."
)
return
"embed"
if
task
==
"score"
:
new_task
=
self
.
_get_default_pooling_task
(
architectures
)
return
"classify"
if
new_task
==
"classify"
else
"embed"
return
"none"
if
self
.
task
is
not
None
:
runner
:
RunnerOption
=
"auto"
convert
:
ConvertOption
=
"auto"
msg_prefix
=
(
"The 'task' option has been deprecated and will be "
"removed in v0.13.0 or v1.0, whichever comes first."
)
msg_hint
=
"Please remove this option."
is_generative_task
=
self
.
task
in
_RUNNER_TASKS
[
"generate"
]
is_pooling_task
=
self
.
task
in
_RUNNER_TASKS
[
"pooling"
]
if
is_generative_model
and
is_pooling_model
:
if
is_generative_task
:
runner
=
"generate"
convert
=
"auto"
msg_hint
=
(
"Please replace this option with `--runner "
"generate` to continue using this model "
"as a generative model."
)
elif
is_pooling_task
:
runner
=
"pooling"
convert
=
"auto"
msg_hint
=
(
"Please replace this option with `--runner "
"pooling` to continue using this model "
"as a pooling model."
)
else
:
# task == "auto"
pass
elif
is_generative_model
or
is_pooling_model
:
if
is_generative_task
:
runner
=
"generate"
convert
=
"auto"
msg_hint
=
"Please remove this option"
elif
is_pooling_task
:
runner
=
"pooling"
convert
=
_task_to_convert
(
self
.
task
)
msg_hint
=
(
"Please replace this option with `--convert "
f
"
{
convert
}
` to continue using this model "
"as a pooling model."
)
else
:
# task == "auto"
pass
else
:
# Neither generative nor pooling model - try to convert if possible
if
is_pooling_task
:
runner
=
"pooling"
convert
=
_task_to_convert
(
self
.
task
)
msg_hint
=
(
"Please replace this option with `--runner pooling "
f
"--convert
{
convert
}
` to continue using this model "
"as a pooling model."
)
else
:
debug_info
=
{
"architectures"
:
architectures
,
"is_generative_model"
:
is_generative_model
,
"is_pooling_model"
:
is_pooling_model
,
}
raise
AssertionError
(
"The model should be a generative or "
"pooling model when task is set to "
f
"
{
self
.
task
!
r
}
. Found:
{
debug_info
}
"
)
self
.
runner
=
runner
self
.
convert
=
convert
msg
=
f
"
{
msg_prefix
}
{
msg_hint
}
"
warnings
.
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
self
.
runner_type
=
self
.
_get_runner_type
(
architectures
,
self
.
runner
)
self
.
convert_type
=
self
.
_get_convert_type
(
architectures
,
self
.
runner_type
,
self
.
convert
...
...
@@ -657,6 +542,11 @@ class ModelConfig:
self
.
original_max_model_len
=
self
.
max_model_len
self
.
max_model_len
=
self
.
get_and_verify_max_len
(
self
.
max_model_len
)
if
self
.
is_encoder_decoder
:
self
.
mm_processor_cache_gb
=
0
logger
.
info
(
"Encoder-decoder model detected, disabling mm processor cache."
)
# Init multimodal config if needed
if
self
.
_model_info
.
supports_multimodal
:
if
(
...
...
@@ -710,6 +600,14 @@ class ModelConfig:
self
.
_verify_cuda_graph
()
self
.
_verify_bnb_config
()
@
field_validator
(
"tokenizer"
,
"max_model_len"
,
mode
=
"wrap"
)
@
classmethod
def
_skip_none_validation
(
cls
,
value
:
Any
,
handler
:
Callable
)
->
Any
:
"""Skip validation if the value is `None` when initialisation is delayed."""
if
value
is
None
:
return
value
return
handler
(
value
)
@
field_validator
(
"tokenizer_mode"
,
mode
=
"after"
)
def
_lowercase_tokenizer_mode
(
cls
,
tokenizer_mode
:
str
)
->
str
:
return
tokenizer_mode
.
lower
()
...
...
@@ -723,10 +621,19 @@ class ModelConfig:
@
model_validator
(
mode
=
"after"
)
def
validate_model_config_after
(
self
:
"ModelConfig"
)
->
"ModelConfig"
:
"""Called after __post_init__"""
if
not
isinstance
(
self
.
tokenizer
,
str
):
raise
ValueError
(
"tokenizer must be a string after __post_init__."
)
raise
ValueError
(
f
"tokenizer must be a string, got "
f
"
{
type
(
self
.
tokenizer
).
__name__
}
:
{
self
.
tokenizer
!
r
}
. "
"Please provide a valid tokenizer path or HuggingFace model ID."
)
if
not
isinstance
(
self
.
max_model_len
,
int
):
raise
ValueError
(
"max_model_len must be an integer after __post_init__."
)
raise
ValueError
(
f
"max_model_len must be a positive integer, "
f
"got
{
type
(
self
.
max_model_len
).
__name__
}
:
{
self
.
max_model_len
!
r
}
. "
"Example: max_model_len=2048"
)
return
self
def
_get_transformers_backend_cls
(
self
)
->
str
:
...
...
@@ -906,6 +813,13 @@ class ModelConfig:
runner_type
:
RunnerType
,
convert
:
ConvertOption
,
)
->
ConvertType
:
if
convert
==
"reward"
:
logger
.
warning
(
"`--convert reward` is deprecated and will be removed in v0.15. "
"Please use `--convert embed` instead."
)
return
"embed"
if
convert
!=
"auto"
:
return
convert
...
...
@@ -921,22 +835,6 @@ class ModelConfig:
return
convert_type
def
_get_default_pooling_task
(
self
,
architectures
:
list
[
str
],
)
->
Literal
[
"embed"
,
"classify"
,
"reward"
]:
if
self
.
registry
.
is_cross_encoder_model
(
architectures
,
self
):
return
"classify"
for
arch
in
architectures
:
match
=
try_match_architecture_defaults
(
arch
,
runner_type
=
"pooling"
)
if
match
:
_
,
(
_
,
convert_type
)
=
match
assert
convert_type
!=
"none"
return
convert_type
return
"embed"
def
_parse_quant_hf_config
(
self
,
hf_config
:
PretrainedConfig
):
quant_cfg
=
getattr
(
hf_config
,
"quantization_config"
,
None
)
if
quant_cfg
is
None
:
...
...
@@ -1308,7 +1206,15 @@ class ModelConfig:
//
block
.
attention
.
n_heads_in_group
)
raise
RuntimeError
(
"Couldn't determine number of kv heads"
)
raise
RuntimeError
(
"Could not determine the number of key-value attention heads "
"from model configuration. "
f
"Model:
{
self
.
model
}
, Architecture:
{
self
.
architectures
}
. "
"This usually indicates an unsupported model architecture or "
"missing configuration. "
"Please check if your model is supported at: "
"https://docs.vllm.ai/en/latest/models/supported_models.html"
)
if
self
.
is_attention_free
:
return
0
...
...
@@ -1902,6 +1808,7 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
(
"ForTextEncoding"
,
(
"pooling"
,
"embed"
)),
(
"EmbeddingModel"
,
(
"pooling"
,
"embed"
)),
(
"ForSequenceClassification"
,
(
"pooling"
,
"classify"
)),
(
"ForTokenClassification"
,
(
"pooling"
,
"classify"
)),
(
"ForAudioClassification"
,
(
"pooling"
,
"classify"
)),
(
"ForImageClassification"
,
(
"pooling"
,
"classify"
)),
(
"ForVideoClassification"
,
(
"pooling"
,
"classify"
)),
...
...
vllm/config/parallel.py
View file @
a3f8d5dd
...
...
@@ -317,11 +317,6 @@ class ParallelConfig:
"num_redundant_experts."
)
if
self
.
prefill_context_parallel_size
>
1
:
raise
ValueError
(
"Prefill context parallelism is not fully supported. "
"Please set prefill_context_parallel_size to 1."
)
return
self
@
property
...
...
vllm/config/pooler.py
View file @
a3f8d5dd
...
...
@@ -111,13 +111,15 @@ class PoolerConfig:
def
get_use_activation
(
o
:
object
):
if
softmax
:
=
getattr
(
o
,
"softmax"
,
None
)
is
not
None
:
logger
.
warning_once
(
"softmax will be deprecated, please use use_activation instead."
"softmax will be deprecated and will be removed in v0.15. "
"Please use use_activation instead."
)
return
softmax
if
activation
:
=
getattr
(
o
,
"activation"
,
None
)
is
not
None
:
logger
.
warning_once
(
"activation will be deprecated, please use use_activation instead."
"activation will be deprecated and will be removed in v0.15. "
"Please use use_activation instead."
)
return
activation
...
...
vllm/config/scheduler.py
View file @
a3f8d5dd
...
...
@@ -122,10 +122,12 @@ class SchedulerConfig:
the default scheduler. Can be a class directly or the path to a class of
form "mod.custom_class"."""
disable_hybrid_kv_cache_manager
:
bool
=
Fals
e
disable_hybrid_kv_cache_manager
:
bool
|
None
=
Non
e
"""If set to True, KV cache manager will allocate the same size of KV cache
for all attention layers even if there are multiple type of attention layers
like full attention and sliding window attention.
If set to None, the default value will be determined based on the environment
and starting configuration.
"""
async_scheduling
:
bool
=
False
...
...
vllm/config/utils.py
View file @
a3f8d5dd
...
...
@@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
)
def
getattr_iter
(
object
:
object
,
names
:
Iterable
[
str
],
default
:
Any
)
->
Any
:
def
getattr_iter
(
object
:
object
,
names
:
Iterable
[
str
],
default
:
Any
,
warn
:
bool
=
False
)
->
Any
:
"""
A helper function that retrieves an attribute from an object which may
have multiple possible names. This is useful when fetching attributes from
arbitrary `transformers.PretrainedConfig` instances.
In the case where the first name in `names` is the preferred name, and
any other names are deprecated aliases, setting `warn=True` will log a
warning when a deprecated name is used.
"""
for
name
in
names
:
for
i
,
name
in
enumerate
(
names
)
:
if
hasattr
(
object
,
name
):
if
warn
and
i
>
0
:
logger
.
warning_once
(
"%s contains a deprecated attribute name '%s'. "
"Please use the preferred attribute name '%s' instead."
,
type
(
object
).
__name__
,
name
,
names
[
0
],
)
return
getattr
(
object
,
name
)
return
default
...
...
vllm/config/vllm.py
View file @
a3f8d5dd
...
...
@@ -666,8 +666,9 @@ class VllmConfig:
default_config
=
OPTIMIZATION_LEVEL_TO_CONFIG
[
self
.
optimization_level
]
self
.
_apply_optimization_level_defaults
(
default_config
)
if
(
self
.
compilation_config
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
self
.
compilation_config
.
cudagraph_mode
.
requires_piecewise_compilation
()
and
self
.
compilation_config
.
mode
!=
CompilationMode
.
VLLM_COMPILE
):
logger
.
info
(
...
...
@@ -692,22 +693,29 @@ class VllmConfig:
if
current_platform
.
support_static_graph_mode
():
# if cudagraph_mode has full cudagraphs, we need to check support
if
(
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
and
self
.
model_config
is
not
None
):
if
self
.
model_config
.
pooler_config
is
not
None
:
if
model_config
:
=
self
.
model_config
:
if
(
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
and
model_config
.
pooler_config
is
not
None
)
:
logger
.
warning_once
(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self
.
compilation_config
.
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
elif
self
.
model_config
.
is_encoder_decoder
:
logger
.
warning_once
(
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
elif
(
model_config
.
is_encoder_decoder
and
self
.
compilation_config
.
cudagraph_mode
not
in
(
CUDAGraphMode
.
NONE
,
CUDAGraphMode
.
FULL_DECODE_ONLY
)
):
logger
.
info_once
(
"Encoder-decoder models do not support %s. "
"Overriding cudagraph_mode to FULL_DECODE_ONLY."
,
self
.
compilation_config
.
cudagraph_mode
.
name
,
)
self
.
compilation_config
.
cudagraph_mode
=
(
CUDAGraphMode
.
FULL_DECODE_ONLY
)
self
.
compilation_config
.
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
# disable cudagraph when enforce eager execution
if
self
.
model_config
is
not
None
and
self
.
model_config
.
enforce_eager
:
...
...
@@ -742,27 +750,17 @@ class VllmConfig:
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
self
.
_set_compile_ranges
()
if
self
.
model_config
and
self
.
model_config
.
is_encoder_decoder
:
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
self
.
scheduler_config
.
max_num_encoder_input_tokens
=
(
MULTIMODAL_REGISTRY
.
get_encdec_max_encoder_len
(
self
.
model_config
)
)
logger
.
debug
(
"
Encoder-decoder model detected: sett
ing "
"
`max_num_encoder_input_tokens` to encoder length (%s)"
,
self
.
scheduler_config
.
max_num_encoder_input_tokens
,
if
(
self
.
model_config
and
self
.
model_config
.
architecture
==
"WhisperForConditionalGeneration"
and
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
)
!=
"spawn"
)
:
logger
.
warning
(
"Whisper is known to have issues with "
"
forked workers. If startup is hang
ing
,
"
"
try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
if
(
self
.
model_config
.
architecture
==
"WhisperForConditionalGeneration"
and
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
)
!=
"spawn"
):
logger
.
warning
(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
if
(
self
.
kv_events_config
is
not
None
...
...
@@ -812,11 +810,6 @@ class VllmConfig:
f
"(
{
self
.
parallel_config
.
cp_kv_cache_interleave_size
}
)."
)
assert
(
self
.
parallel_config
.
cp_kv_cache_interleave_size
==
1
or
self
.
speculative_config
is
None
),
"MTP with cp_kv_cache_interleave_size > 1 is not supported now."
# Do this after all the updates to compilation_config.mode
self
.
compilation_config
.
set_splitting_ops_for_v1
(
all2all_backend
=
self
.
parallel_config
.
all2all_backend
,
...
...
@@ -894,17 +887,48 @@ class VllmConfig:
if
not
self
.
instance_id
:
self
.
instance_id
=
random_uuid
()[:
5
]
if
not
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
:
# logger should only print warning message for hybrid models. As we
# can't know whether the model is hybrid or not now, so we don't log
# warning message here and will log it later.
if
not
current_platform
.
support_hybrid_kv_cache
():
# Hybrid KV cache manager is not supported on non-GPU platforms.
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
True
# Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-kv-cache-manager): error if runtime
# disables it
# - No preference: auto-disable for unsupported features (e.g. kv connector)
# - Explicit disable (--disable-kv-cache-manager): always respect it
need_disable_hybrid_kv_cache_manager
=
False
# logger should only print warning message for hybrid models. As we
# can't know whether the model is hybrid or not now, so we don't log
# warning message here and will log it later.
if
not
current_platform
.
support_hybrid_kv_cache
():
# Hybrid KV cache manager is not supported on non-GPU platforms.
need_disable_hybrid_kv_cache_manager
=
True
if
self
.
kv_events_config
is
not
None
:
# Hybrid KV cache manager is not compatible with KV events.
need_disable_hybrid_kv_cache_manager
=
True
if
(
self
.
model_config
is
not
None
and
self
.
model_config
.
attention_chunk_size
is
not
None
):
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_eagle
()
):
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
need_disable_hybrid_kv_cache_manager
=
True
elif
not
envs
.
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE
:
logger
.
warning
(
"There is a latency regression when using chunked local"
" attention with the hybrid KV cache manager. Disabling"
" it, by default. To enable it, set the environment "
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
)
# Hybrid KV cache manager is not yet supported with chunked
# local attention.
need_disable_hybrid_kv_cache_manager
=
True
if
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
is
None
:
# Default to disable HMA, but only if the user didn't express a preference.
if
self
.
kv_transfer_config
is
not
None
:
# NOTE(Kuntai): turn HMA off for connector for now.
# TODO(Kuntai): have a more elegent solution to check and
# turn off HMA for connector that does not support HMA.
# NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
need_disable_hybrid_kv_cache_manager
=
True
logger
.
warning
(
"Turning off hybrid kv cache manager because "
"`--kv-transfer-config` is set. This will reduce the "
...
...
@@ -912,33 +936,26 @@ class VllmConfig:
"or Mamba attention. If you are a developer of kv connector"
", please consider supporting hybrid kv cache manager for "
"your connector by making sure your connector is a subclass"
" of `SupportsHMA` defined in kv_connector/v1/base.py."
" of `SupportsHMA` defined in kv_connector/v1/base.py and"
" use --no-disable-hybrid-kv-cache-manager to start vLLM."
)
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
True
if
self
.
kv_events_config
is
not
None
:
# Hybrid KV cache manager is not compatible with KV events.
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
True
if
(
self
.
model_config
is
not
None
and
self
.
model_config
.
attention_chunk_size
is
not
None
):
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_eagle
()
):
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
True
elif
not
envs
.
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE
:
logger
.
warning
(
"There is a latency regression when using chunked local"
" attention with the hybrid KV cache manager. Disabling"
" it, by default. To enable it, set the environment "
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
)
# Hybrid KV cache manager is not yet supported with chunked
# local attention.
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
True
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
(
need_disable_hybrid_kv_cache_manager
)
elif
(
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
is
False
and
need_disable_hybrid_kv_cache_manager
):
raise
ValueError
(
"Hybrid KV cache manager was explicitly enabled but is not "
"supported in this configuration. Consider omitting the "
"--no-disable-hybrid-kv-cache-manager flag to let vLLM decide"
" automatically."
)
if
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
is
None
:
# Default to enable HMA if not explicitly disabled by user or logic above.
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
False
if
self
.
compilation_config
.
debug_dump_path
:
self
.
compilation_config
.
debug_dump_path
=
(
...
...
@@ -1006,7 +1023,7 @@ class VllmConfig:
max_graph_size = min(max_num_seqs * 2, 512)
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
# up to max_graph_size
cuda
_
graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
cudagraph_
capture_
sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_graph_size + 1, 16))
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
a3f8d5dd
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
pickle
import
threading
import
time
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
...
...
@@ -43,6 +44,33 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
# Decode bytes into an int, interpreting them as big-endian
# (`int.from_bytes` with `byteorder` pre-bound to "big").
from_bytes_big
=
functools
.
partial
(
int
.
from_bytes
,
byteorder
=
"big"
)
# Lock used only as a memory fence for flags kept in shared memory that are
# read and written without locks (producer-consumer synchronization).
# No work is ever done while it is held; see memory_fence().
_memory_fence_lock
=
threading
.
Lock
()
def
memory_fence
():
"""
Issue a memory barrier by acquiring and immediately releasing a lock.

Intended for lock-free producer-consumer code that polls flags in
shared memory: call it after writing a flag so the store is published,
and before reading flags so stale values are not observed.

The body does no work while holding the lock; the acquire/release
round-trip itself is what is relied on as the barrier.

NOTE(review): `_memory_fence_lock` is a `threading.Lock` private to
this process, so it provides no mutual exclusion between processes. The
cross-process visibility guarantee, the claim of sequentially consistent
semantics on all major platforms (POSIX, Windows), and the ~20ns cost
quoted by the original author depend on how the interpreter implements
the lock and are not verified here -- confirm before relying on them on
a new platform or interpreter.

Takes no arguments and returns None.
"""
# The lock round-trip is the whole point: nothing runs while it is held.
# The context manager guarantees release even if an exception (for example
# KeyboardInterrupt) is raised inside the block.
with
_memory_fence_lock
:
pass
def to_bytes_big(value: int, size: int) -> bytes:
    """Encode ``value`` as a big-endian byte string of exactly ``size`` bytes.

    Args:
        value: Integer to encode; it is encoded as an unsigned number.
        size: Length in bytes of the resulting byte string.

    Returns:
        ``size`` bytes, most significant byte first.

    Raises:
        OverflowError: If ``value`` is negative or does not fit in ``size``
            bytes (propagated from ``int.to_bytes``).
    """
    encoded = value.to_bytes(size, "big")
    return encoded
...
...
@@ -414,6 +442,10 @@ class MessageQueue:
n_warning
=
1
while
True
:
with
self
.
buffer
.
get_metadata
(
self
.
current_idx
)
as
metadata_buffer
:
# Memory fence ensures we see the latest read flags from readers.
# Without this, we may read stale flags from our CPU cache and
# spin indefinitely even though readers have completed.
memory_fence
()
read_count
=
sum
(
metadata_buffer
[
1
:])
written_flag
=
metadata_buffer
[
0
]
if
written_flag
and
read_count
!=
self
.
buffer
.
n_reader
:
...
...
@@ -458,6 +490,10 @@ class MessageQueue:
metadata_buffer
[
i
]
=
0
# mark the block as written
metadata_buffer
[
0
]
=
1
# Memory fence ensures the write is visible to readers on other cores
# before we proceed. Without this, readers may spin indefinitely
# waiting for a write that's stuck in our CPU's store buffer.
memory_fence
()
self
.
current_idx
=
(
self
.
current_idx
+
1
)
%
self
.
buffer
.
max_chunks
break
...
...
@@ -473,6 +509,10 @@ class MessageQueue:
n_warning
=
1
while
True
:
with
self
.
buffer
.
get_metadata
(
self
.
current_idx
)
as
metadata_buffer
:
# Memory fence ensures we see the latest writes from the writer.
# Without this, we may read stale flags from our CPU cache
# and spin indefinitely even though writer has updated them.
memory_fence
()
read_flag
=
metadata_buffer
[
self
.
local_reader_rank
+
1
]
written_flag
=
metadata_buffer
[
0
]
if
not
written_flag
or
read_flag
:
...
...
@@ -513,6 +553,10 @@ class MessageQueue:
# caller has read from the buffer
# set the read flag
metadata_buffer
[
self
.
local_reader_rank
+
1
]
=
1
# Memory fence ensures the read flag is visible to the writer.
# Without this, writer may not see our read completion and
# could wait indefinitely for all readers to finish.
memory_fence
()
self
.
current_idx
=
(
self
.
current_idx
+
1
)
%
self
.
buffer
.
max_chunks
self
.
_read_spin_timer
.
record_activity
()
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment