Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a99300bd
Commit
a99300bd
authored
Sep 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev
parents
cc3e01c7
5438967f
Changes
512
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
722 additions
and
301 deletions
+722
-301
vllm/benchmarks/throughput.py
vllm/benchmarks/throughput.py
+8
-0
vllm/compilation/activation_quant_fusion.py
vllm/compilation/activation_quant_fusion.py
+139
-35
vllm/compilation/backends.py
vllm/compilation/backends.py
+5
-15
vllm/compilation/base_static_graph.py
vllm/compilation/base_static_graph.py
+1
-4
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+21
-1
vllm/compilation/cuda_graph.py
vllm/compilation/cuda_graph.py
+4
-4
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+36
-3
vllm/compilation/fix_functionalization.py
vllm/compilation/fix_functionalization.py
+17
-0
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+22
-48
vllm/compilation/fusion_attn.py
vllm/compilation/fusion_attn.py
+191
-72
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+20
-0
vllm/compilation/monitor.py
vllm/compilation/monitor.py
+1
-1
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+1
-1
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+2
-0
vllm/config/__init__.py
vllm/config/__init__.py
+131
-64
vllm/config/cache.py
vllm/config/cache.py
+9
-7
vllm/config/compilation.py
vllm/config/compilation.py
+9
-8
vllm/config/parallel.py
vllm/config/parallel.py
+103
-36
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+1
-1
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
512 of 512+
files are displayed.
Plain diff
Email patch
vllm/benchmarks/throughput.py
View file @
a99300bd
...
...
@@ -434,6 +434,14 @@ def validate_args(args):
if
args
.
backend
==
"mii"
and
args
.
tokenizer
!=
args
.
model
:
raise
ValueError
(
"Tokenizer must be the same as the model for MII backend."
)
# --data-parallel is not supported currently.
# https://github.com/vllm-project/vllm/issues/16222
if
args
.
data_parallel_size
>
1
:
raise
ValueError
(
"Data parallel is not supported in offline benchmark, "
"please use benchmark serving instead"
)
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
...
...
vllm/compilation/activation_quant_fusion.py
View file @
a99300bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
import
torch
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._inductor.pattern_matcher
import
(
PatternMatcherPass
,
fwd_only
,
register_replacement
)
from
torch._ops
import
OpOverload
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8StaticTensorSym
,
kNvfp4Quant
,
kStaticTensorScale
)
from
vllm.platforms
import
current_platform
from
.fusion
import
QUANT_OPS
,
empty_bf16
,
empty_fp32
,
empty_i32
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
SILU_MUL_OP
=
torch
.
ops
.
_C
.
silu_and_mul
.
default
def
silu_mul_pattern_static
(
result
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at1
=
auto_functionalized
(
torch
.
ops
.
_C
.
silu_and_mul
.
default
,
result
=
result_silu_mul
,
input
=
input
)
at2
=
auto_functionalized
(
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
,
result
=
result
,
input
=
at1
[
1
],
scale
=
scale
)
return
at2
[
1
]
# FUSED_OPS: dict[QuantKey, OpOverload] = {
# kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
# }
# silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr(
# torch.ops._C, "silu_and_mul_nvfp4_quant"))
# if silu_and_mul_nvfp4_quant_supported:
# FUSED_OPS[
# kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
def
silu_mul_replacement_static
(
result
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at
=
auto_functionalized
(
torch
.
ops
.
_C
.
silu_and_mul_quant
.
default
,
result
=
result
,
input
=
input
,
scale
=
scale
)
return
at
[
1
]
class
ActivationQuantPattern
(
ABC
):
"""
The base class for Activation+Quant fusions.
Should not be used directly.
"""
def
__init__
(
self
,
quant_key
:
QuantKey
,
):
self
.
quant_key
=
quant_key
self
.
quant_dtype
=
quant_key
.
dtype
def
empty_bf16
(
*
args
,
**
kwargs
):
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
assert
self
.
quant_key
in
QUANT_OPS
,
\
f
"unsupported quantization scheme
{
self
.
quant_key
}
"
self
.
QUANT_OP
=
QUANT_OPS
[
self
.
quant_key
]
assert
self
.
quant_key
in
FUSED_OPS
,
\
f
"unsupported fusion scheme
{
self
.
quant_key
}
"
self
.
FUSED_OP
=
FUSED_OPS
[
self
.
quant_key
]
def
empty_
fp8
(
*
args
,
**
kwargs
):
fp8
=
current_platform
.
fp8_dtype
()
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
fp8
,
device
=
"cuda"
)
def
empty_
quant
(
self
,
*
args
,
**
kwargs
):
kwargs
=
{
'dtype'
:
self
.
quant_dtype
,
'device'
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
@
abstractmethod
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
raise
NotImplementedError
def
empty_fp32
(
*
args
,
**
kwargs
):
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
class
SiluMulFp8StaticQuantPattern
(
ActivationQuantPattern
):
"""
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
def
__init__
(
self
,
symmetric
:
bool
=
True
):
quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
)
super
().
__init__
(
quant_key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
result
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at1
=
auto_functionalized
(
SILU_MUL_OP
,
result
=
result_silu_mul
,
input
=
input
)
at2
=
auto_functionalized
(
self
.
QUANT_OP
,
result
=
result
,
input
=
at1
[
1
],
scale
=
scale
)
return
at2
[
1
]
def
replacement
(
result
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at
=
auto_functionalized
(
self
.
FUSED_OP
,
result
=
result
,
input
=
input
,
scale
=
scale
)
return
at
[
1
]
inputs
=
[
self
.
empty_quant
(
5
,
4
),
# result
empty_bf16
(
5
,
4
),
# result_silu_mul
empty_bf16
(
5
,
4
),
# input
empty_fp32
(
1
,
1
)
# scale
]
register_replacement
(
pattern
,
replacement
,
inputs
,
fwd_only
,
pm_pass
)
class
SiluMulNvfp4QuantPattern
(
ActivationQuantPattern
):
"""
Fusion for SiluMul+Nvfp4Quant Pattern
"""
def
__init__
(
self
):
super
().
__init__
(
kNvfp4Quant
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at1
=
auto_functionalized
(
SILU_MUL_OP
,
result
=
result_silu_mul
,
input
=
input
)
at2
=
auto_functionalized
(
self
.
QUANT_OP
,
output
=
result
,
input
=
at1
[
1
],
output_scale
=
output_scale
,
input_scale
=
scale
)
return
at2
[
1
],
at2
[
2
]
def
replacement
(
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
at
=
auto_functionalized
(
self
.
FUSED_OP
,
result
=
result
,
result_block_scale
=
output_scale
,
input
=
input
,
input_global_scale
=
scale
)
return
at
[
1
],
at
[
2
]
inputs
=
[
self
.
empty_quant
(
5
,
32
),
# result
empty_i32
(
128
,
4
),
# output_scale
empty_bf16
(
5
,
64
),
# result_silu_mul
empty_bf16
(
5
,
64
),
# input
empty_fp32
(
1
,
1
)
# scale
]
register_replacement
(
pattern
,
replacement
,
inputs
,
fwd_only
,
pm_pass
)
class
ActivationQuantFusionPass
(
VllmInductorPass
):
...
...
@@ -61,21 +162,19 @@ class ActivationQuantFusionPass(VllmInductorPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"activation_quant_fusion_pass"
)
inputs
=
[
empty_fp8
(
5
,
4
),
# Quant output
empty_bf16
(
5
,
4
),
# Silu_and_mul output
empty_bf16
(
5
,
4
),
# Input
empty_fp32
(
1
,
1
)
# Scale
]
register_replacement
(
silu_mul_pattern_static
,
silu_mul_replacement_static
,
inputs
,
fwd_only
,
self
.
patterns
)
pattern_silu_mul_fp8
=
SiluMulFp8StaticQuantPattern
()
pattern_silu_mul_fp8
.
register
(
self
.
patterns
)
if
silu_and_mul_nvfp4_quant_supported
:
pattern_silu_mul_nvfp4
=
SiluMulNvfp4QuantPattern
()
pattern_silu_mul_nvfp4
.
register
(
self
.
patterns
)
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
self
.
begin
()
...
...
@@ -87,3 +186,8 @@ class ActivationQuantFusionPass(VllmInductorPass):
self
.
dump_graph
(
graph
,
"after_act_quant_fusion"
)
self
.
end_and_log
()
def
uuid
(
self
):
return
VllmInductorPass
.
hash_source
(
self
,
ActivationQuantPattern
,
SiluMulFp8StaticQuantPattern
,
SiluMulNvfp4QuantPattern
)
vllm/compilation/backends.py
View file @
a99300bd
...
...
@@ -271,7 +271,7 @@ def split_graph(graph: fx.GraphModule,
outputs
.
append
(
SplitItem
(
name
,
graph_id
,
(
graph_id
in
split_op_graphs
),
module
))
# sort by inte
t
ger graph_id, rather than string name
# sort by integer graph_id, rather than string name
outputs
.
sort
(
key
=
lambda
x
:
x
.
graph_id
)
return
split_gm
,
outputs
...
...
@@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
def
__init__
(
self
,
module
:
torch
.
fx
.
GraphModule
,
compile_submod_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
graph_pool
,
vllm_backend
:
"VllmBackend"
):
vllm_backend
:
"VllmBackend"
):
super
().
__init__
(
module
)
from
torch._guards
import
detect_fake_mode
self
.
fake_mode
=
detect_fake_mode
()
self
.
compile_submod_names
=
compile_submod_names
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
graph_pool
=
graph_pool
self
.
vllm_config
=
vllm_config
self
.
vllm_backend
=
vllm_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
...
...
@@ -359,7 +358,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
runnable
=
piecewise_backend
,
vllm_config
=
self
.
vllm_config
,
runtime_mode
=
CUDAGraphMode
.
PIECEWISE
,
graph_pool
=
self
.
graph_pool
,
cudagraph_options
=
CUDAGraphOptions
(
debug_log_enable
=
piecewise_backend
.
is_first_graph
,
gc_disable
=
not
piecewise_backend
.
is_first_graph
,
...
...
@@ -405,7 +403,6 @@ class VllmBackend:
vllm_config
:
VllmConfig
compilation_config
:
CompilationConfig
graph_pool
:
Any
_called
:
bool
=
False
# the graph we compiled
graph
:
fx
.
GraphModule
...
...
@@ -427,19 +424,12 @@ class VllmBackend:
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. la
u
nguage_model, vision_model, etc.
# e.g. language_model, vision_model, etc.
# when multiple parts are initialized as independent
# models, we need to use the model_tag to distinguish
# them, e.g. backbone (default), eagle_head, etc.
self
.
prefix
=
prefix
or
model_tag
global_graph_pool
=
current_platform
.
get_global_graph_pool
()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self
.
graph_pool
=
global_graph_pool
# Passes to run on the graph post-grad.
self
.
post_grad_pass_manager
=
PostGradPassManager
()
...
...
@@ -484,7 +474,7 @@ class VllmBackend:
factors
=
[]
# 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affect
s
the computation graph.
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
env_hash
=
envs
.
compute_hash
()
factors
.
append
(
env_hash
)
...
...
@@ -586,7 +576,7 @@ class VllmBackend:
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter
(
self
.
split_gm
,
submod_names_to_compile
,
self
.
vllm_config
,
self
.
graph_pool
,
self
.
vllm_config
,
self
).
run
(
*
example_inputs
)
graph_path
=
os
.
path
.
join
(
local_cache_dir
,
"computation_graph.py"
)
...
...
vllm/compilation/base_static_graph.py
View file @
a99300bd
...
...
@@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol):
"""
def
__init__
(
self
,
runnable
:
Callable
,
vllm_config
:
VllmConfig
,
runtime_mode
:
CUDAGraphMode
,
graph_pool
:
Any
,
**
kwargs
):
runtime_mode
:
CUDAGraphMode
,
**
kwargs
):
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.
...
...
@@ -25,9 +25,6 @@ class AbstractStaticGraphWrapper(Protocol):
graph runtime. See CUDAGraphMode in vllm/config.py.
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
are used as concrete runtime mode for cudagraph dispatching.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.
...
...
vllm/compilation/collective_fusion.py
View file @
a99300bd
...
...
@@ -10,6 +10,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from
torch._inductor.pattern_matcher
import
PatternMatcherPass
from
torch.distributed._symmetric_memory
import
enable_symm_mem_for_group
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tp_group
,
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
(
...
...
@@ -18,6 +19,7 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
...
@@ -348,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
class
AsyncTPPass
(
VllmInductorPass
):
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
...
...
@@ -401,6 +404,18 @@ if flashinfer_comm is not None:
6
:
MiB
//
2
,
# 512KB
8
:
MiB
//
2
,
# 512KB
}
try
:
_FI_MAX_SIZES
.
update
({
int
(
k
):
int
(
float
(
v
)
*
MiB
)
for
k
,
v
in
envs
.
VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB
.
items
()
})
except
Exception
as
e
:
raise
ValueError
(
"Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: "
+
str
(
e
))
from
e
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE
=
MiB
//
2
...
...
@@ -465,7 +480,8 @@ if flashinfer_comm is not None:
quant_out
=
quant_out
,
scale_out
=
scale_out
,
# in vllm we only support swizzled layout
layout_code
=
flashinfer_comm
.
FP4QuantizationSFLayout
.
SWIZZLED
,
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
scale_factor
=
scale_factor
,
)
else
:
...
...
@@ -1107,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass):
# in fallback path, when we don't use flashinfer
fuse_rms_quant
=
config
.
compilation_config
.
pass_config
.
enable_fusion
)
self
.
register_patterns
()
@
enable_fake_mode
def
register_patterns
(
self
):
for
epsilon
in
[
1e-5
,
1e-6
]:
AllReduceFusedRMSNormStaticQuantFP8Pattern
(
epsilon
,
...
...
vllm/compilation/cuda_graph.py
View file @
a99300bd
...
...
@@ -67,11 +67,9 @@ class CUDAGraphWrapper:
runnable
:
Callable
,
vllm_config
:
VllmConfig
,
runtime_mode
:
CUDAGraphMode
,
graph_pool
:
Any
=
None
,
cudagraph_options
:
Optional
[
CUDAGraphOptions
]
=
None
):
self
.
runnable
=
runnable
self
.
vllm_config
=
vllm_config
self
.
graph_pool
=
graph_pool
self
.
runtime_mode
=
runtime_mode
self
.
compilation_config
=
vllm_config
.
compilation_config
...
...
@@ -81,8 +79,10 @@ class CUDAGraphWrapper:
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
# need to initialize a CUDAGraphWrapper.
assert
self
.
runtime_mode
!=
CUDAGraphMode
.
NONE
if
self
.
graph_pool
is
None
:
self
.
graph_pool
=
current_platform
.
get_global_graph_pool
()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self
.
graph_pool
=
current_platform
.
get_global_graph_pool
()
if
cudagraph_options
is
None
:
cudagraph_options
=
CUDAGraphOptions
()
...
...
vllm/compilation/decorators.py
View file @
a99300bd
...
...
@@ -54,6 +54,14 @@ def _should_ignore_torch_compile(cls) -> bool:
return
getattr
(
cls
,
IGNORE_COMPILE_KEY
,
False
)
@
overload
def
support_torch_compile
(
*
,
enable_if
:
Optional
[
Callable
[[
VllmConfig
],
bool
]]
=
None
,
)
->
Callable
[[
_T
],
_T
]:
...
@
overload
def
support_torch_compile
(
*
,
...
...
@@ -71,6 +79,7 @@ def support_torch_compile(
cls
:
Optional
[
_T
]
=
None
,
*
,
dynamic_arg_dims
:
Optional
[
dict
[
str
,
Union
[
int
,
list
[
int
]]]]
=
None
,
enable_if
:
Optional
[
Callable
[[
VllmConfig
],
bool
]]
=
None
,
)
->
Union
[
Callable
[[
_T
],
_T
],
_T
]:
"""
A decorator to add support for compiling the forward method of a class.
...
...
@@ -120,6 +129,11 @@ def support_torch_compile(
NOTE: if an argument is `None`, it should always be passed as `None` during
the lifetime of the model, otherwise, it cannot be captured as a single
computation graph.
`enable_if` is a function that takes a `VllmConfig` object as input and
returns a boolean value indicating whether to compile the model or not.
This is useful if you want to compile the model only when certain
conditions are met.
"""
def
cls_decorator_helper
(
cls
:
_T
)
->
_T
:
...
...
@@ -151,7 +165,8 @@ def support_torch_compile(
if
k
not
in
sig
.
parameters
:
raise
ValueError
(
f
"Argument
{
k
}
not found in the forward method of
{
cls
}
"
)
return
_support_torch_compile
(
cls
,
inferred_dynamic_arg_dims
)
return
_support_torch_compile
(
cls
,
inferred_dynamic_arg_dims
,
enable_if
)
if
cls
is
not
None
:
# use `support_torch_compile` as a decorator without arguments
...
...
@@ -164,6 +179,7 @@ def support_torch_compile(
def
_support_torch_compile
(
cls
:
_T
,
dynamic_arg_dims
:
dict
[
str
,
Union
[
int
,
list
[
int
]]],
enable_if
:
Optional
[
Callable
[[
VllmConfig
],
bool
]]
=
None
,
)
->
_T
:
"""
A decorator to add support for compiling the forward method of a class.
...
...
@@ -184,13 +200,14 @@ def _support_torch_compile(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
):
old_init
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
vllm_config
=
vllm_config
enable_compile
=
enable_if
is
None
or
enable_if
(
vllm_config
)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self
.
do_not_compile
=
\
vllm_config
.
compilation_config
.
level
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
]
or
not
supports_dynamo
()
or
_should_ignore_torch_compile
(
self
.
__class__
)
self
.
__class__
)
or
not
enable_compile
if
self
.
do_not_compile
:
return
...
...
@@ -273,8 +290,24 @@ def _support_torch_compile(
code
.
co_filename
)
return
inline_call
(
parent
,
func
,
args
,
kwargs
)
# Disable the C++ compilation of symbolic shape guards. C++-fication
# of symbolic shape guards can improve guard overhead. But, since
# vllm skip guards anyways, setting this flag to False can improve
# compile time.
dynamo_config_patches
=
{}
try
:
_
=
torch
.
_dynamo
.
config
.
enable_cpp_symbolic_shape_guards
dynamo_config_patches
[
"enable_cpp_symbolic_shape_guards"
]
=
False
except
AttributeError
:
# Note: this config is not available in torch 2.6, we can skip
# if the config doesn't exist
logger
.
debug
(
"enable_cpp_symbolic_shape_guards config not available"
)
with
patch
.
object
(
InliningInstructionTranslator
,
'inline_call'
,
patched_inline_call
):
patched_inline_call
),
torch
.
_dynamo
.
config
.
patch
(
**
dynamo_config_patches
):
output
=
self
.
compiled_callable
(
*
args
,
**
kwargs
)
return
output
...
...
vllm/compilation/fix_functionalization.py
View file @
a99300bd
...
...
@@ -9,6 +9,7 @@ import torch
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
.fx_utils
import
is_func
from
.vllm_inductor_pass
import
VllmInductorPass
...
...
@@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if
current_platform
.
is_xpu
():
logger
.
debug
(
"XPU platform does not support fix functionalization"
"pass currently."
)
return
self
.
begin
()
self
.
dump_graph
(
graph
,
"before_fix_functionalization"
)
...
...
@@ -89,6 +97,15 @@ class FixFunctionalizationPass(VllmInductorPass):
# node,
# mutated_args,
# args=('result', 'input', 'scale'))
# elif hasattr(
# torch.ops._C, "silu_and_mul_nvfp4_quant"
# ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
# mutated_args = {1: 'result', 2: 'result_block_scale'}
# self.defunctionalize(graph,
# node,
# mutated_args,
# args=('result', 'result_block_scale',
# 'input', 'input_global_scale'))
else
:
continue
# skip the count
...
...
vllm/compilation/fusion.py
View file @
a99300bd
...
...
@@ -12,15 +12,18 @@ from torch._ops import OpOverload
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
)
GroupShape
,
QuantKey
,
ScaleDesc
,
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
kNvfp4Quant
,
kStaticTensorScale
)
from
vllm.platforms
import
current_platform
from
.fx_utils
import
find_getitem_maybe
from
.inductor_pass
import
enable_fake_mode
from
.multi_output_match
import
MultiOutputMatch
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
def
empty_bf16
(
*
args
,
**
kwargs
):
...
...
@@ -31,41 +34,12 @@ def empty_fp32(*args, **kwargs):
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
RMS_ADD_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
class
QuantKey
(
NamedTuple
):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
group_shape: quantization group shape
symmetric: symmetric if True, asymmetric if False
TODO(luka) use QuantDescriptor once standardized:
https://github.com/vllm-project/vllm/issues/8913
def
empty_i32
(
*
args
,
**
kwargs
):
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
"""
dtype
:
torch
.
dtype
static
:
bool
group_shape
:
GroupShape
symmetric
:
bool
=
True
def
__str__
(
self
):
group_shape
=
(
'per_tensor'
if
self
.
group_shape
==
GroupShape
.
PER_TENSOR
else
(
'per_token'
if
self
.
group_shape
==
GroupShape
.
PER_TOKEN
else
str
(
self
.
group_shape
)))
return
(
f
"QuantKey(
{
'static'
if
self
.
static
else
'dynamic'
}
,"
f
"
{
fx
.
graph
.
dtype_abbrs
[
self
.
dtype
]
}
,
{
group_shape
}
,"
f
"
{
'a'
if
not
self
.
symmetric
else
''
}
symmetric)"
)
# kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
# kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
# kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
RMS_ADD_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
QUANT_OPS
:
dict
[
QuantKey
,
OpOverload
]
=
{
# kFp8StaticTensorSym:
...
...
@@ -75,6 +49,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8DynamicTokenSym:
# torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
QUANT_OPS
[
kNvfp4Quant
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
default
# noqa: E501
class
FusedRMSQuantKey
(
NamedTuple
):
...
...
@@ -187,11 +164,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
symmetric
=
True
):
fused_key
=
FusedRMSQuantKey
(
fused_add
=
False
,
quant
=
QuantKey
(
dtype
=
quant_dtype
,
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
symmetric
=
symmetric
))
quant
=
QuantKey
(
dtype
=
quant_dtype
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
))
super
().
__init__
(
epsilon
,
fused_key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
...
...
@@ -244,11 +219,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
symmetric
=
True
):
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
quant
=
QuantKey
(
dtype
=
quant_dtype
,
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
symmetric
=
symmetric
))
quant
=
QuantKey
(
dtype
=
quant_dtype
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
))
super
().
__init__
(
epsilon
,
key
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
,
...
...
@@ -337,10 +310,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
):
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
False
,
quant
=
QuantKey
(
dtype
=
quant_dtype
,
static
=
False
,
group_shape
=
group_shape
,
scale
=
scale
,
symmetric
=
symmetric
))
super
().
__init__
(
epsilon
,
key
)
...
...
@@ -435,10 +408,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
=
GroupShape
.
PER_TOKEN
,
symmetric
=
True
):
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
fused_add
=
True
,
quant
=
QuantKey
(
dtype
=
quant_dtype
,
static
=
False
,
group_shape
=
group_shape
,
scale
=
scale
,
symmetric
=
symmetric
))
super
().
__init__
(
epsilon
,
key
)
...
...
@@ -556,6 +529,7 @@ class FusionPass(VllmInductorPass):
cls
.
_instance
.
pass_config
=
config
.
compilation_config
.
pass_config
return
cls
.
_instance
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
assert
self
.
__class__
.
_instance
is
None
,
\
"FusionPass singleton instance already exists"
...
...
vllm/compilation/fusion_attn.py
View file @
a99300bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
import
torch
import
torch._inductor.pattern_matcher
as
pm
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._inductor.pattern_matcher
import
PatternMatcherPass
from
torch._subclasses.fake_tensor
import
(
FakeTensorMode
,
unset_fake_temporarily
)
from
vllm.attention
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kNvfp4Quant
,
kStaticTensorScale
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
round_up
from
.fusion
import
QUANT_OPS
,
GroupShape
,
QuantKey
,
empty_bf16
,
empty_fp32
from
.fusion
import
QUANT_OPS
,
empty_bf16
,
empty_fp32
,
empty_i32
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
ATTN_OP
=
torch
.
ops
.
vllm
.
unified_attention_with_output
.
default
RESHAPE_OP
=
torch
.
ops
.
aten
.
reshape
.
default
class
AttentionStaticQuantPattern
:
class
AttentionQuantPattern
(
ABC
):
"""
The base class for Attn+Quant fusions.
Should not be used directly.
"""
def
__init__
(
self
,
layer_name
:
str
,
num_heads
:
int
,
head_size
:
int
,
quant_dtype
:
torch
.
dtype
,
symmetric
=
True
,
layer
:
Attention
,
quant_key
:
QuantKey
,
):
self
.
layer_name
=
layer_name
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
quant_dtype
=
quant_dtype
self
.
quant_key
=
QuantKey
(
dtype
=
quant_dtype
,
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
symmetric
=
symmetric
)
self
.
layer
=
layer
self
.
layer_name
=
layer
.
layer_name
self
.
num_heads
=
layer
.
num_heads
self
.
head_size
=
layer
.
head_size
self
.
quant_key
=
quant_key
self
.
quant_dtype
=
quant_key
.
dtype
assert
self
.
quant_key
in
QUANT_OPS
,
\
f
"unsupported quantization scheme
{
self
.
quant_key
}
"
self
.
QUANT_OP
=
QUANT_OPS
[
self
.
quant_key
]
...
...
@@ -48,31 +55,64 @@ class AttentionStaticQuantPattern:
kwargs
=
{
'dtype'
:
self
.
quant_dtype
,
'device'
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
def
register_if_supported
(
self
,
pm_pass
:
PatternMatcherPass
,
layer
:
Attention
):
if
layer
.
impl
.
fused_output_quant_supported
(
self
.
quant_dtype
,
self
.
quant_key
.
static
,
self
.
quant_key
.
group_shape
):
@
staticmethod
def
wrap_trace_fn
(
process_fx
,
trace_fn
):
def
wrapped
(
*
args
,
**
kwargs
):
return
process_fx
(
trace_fn
(
*
args
,
**
kwargs
))
return
wrapped
@
staticmethod
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
):
from
torch._inductor.fx_passes.post_grad
import
view_to_reshape
view_to_reshape
(
gm
)
return
gm
def
register_if_supported
(
self
,
pm_pass
:
PatternMatcherPass
):
if
self
.
layer
.
impl
.
fused_output_quant_supported
(
self
.
quant_key
):
self
.
_register
(
pm_pass
)
@
abstractmethod
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
raise
NotImplementedError
class
AttentionFp8StaticQuantPattern
(
AttentionQuantPattern
):
"""
Fusion for Attention+Fp8StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Fp8StaticQuant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def
__init__
(
self
,
layer
:
Attention
,
symmetric
:
bool
=
True
,
):
quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
kStaticTensorScale
,
symmetric
=
symmetric
)
super
().
__init__
(
layer
,
quant_key
)
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
view_7
=
RESHAPE_OP
(
output_attn
,
[
-
1
,
self
.
num_heads
,
self
.
head_size
])
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
view_7
,
output
=
output_attn
,
layer_name
=
self
.
layer_name
,
output_scale
=
None
)
attn_out_view
=
RESHAPE_OP
(
at1
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
])
output_scale
=
None
,
output_block_scale
=
None
)
attn_out_view
=
RESHAPE_OP
(
at1
[
1
],
[
q
.
shape
[
0
],
self
.
num_heads
*
self
.
head_size
])
at2
=
auto_functionalized
(
self
.
QUANT_OP
,
result
=
output_quant
,
input
=
attn_out_view
,
...
...
@@ -82,47 +122,116 @@ class AttentionStaticQuantPattern:
def
replacement
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
):
view_7
=
RESHAPE_OP
(
output_quant
,
[
-
1
,
self
.
num_heads
,
self
.
head_size
])
# attn output in quant_dtype
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
[
q
.
shape
[
0
],
self
.
num_heads
,
self
.
head_size
],
0.0
,
dtype
=
self
.
quant_dtype
,
device
=
q
.
device
)
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
view_7
,
output
=
output_attn
,
layer_name
=
self
.
layer_name
,
output_scale
=
scale
)
output_scale
=
scale
,
output_block_scale
=
None
)
return
RESHAPE_OP
(
at1
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
])
# Need custom fake mode, otherwise tracing happens with real tensors.
# That would not work for the unified_attention custom op.
with
unset_fake_temporarily
(),
FakeTensorMode
():
inputs
=
[
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# q
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# k
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# v
empty_bf16
(
5
,
self
.
num_heads
*
self
.
head_size
),
# attn_output
self
.
empty_quant
(
5
,
self
.
num_heads
*
self
.
head_size
),
# quant_output
empty_fp32
(
1
,
1
)
# scale
]
inputs
=
[
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# q
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# k
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# v
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# attn_output
self
.
empty_quant
(
5
,
self
.
num_heads
*
self
.
head_size
),
# quant_output
empty_fp32
(
1
,
1
)
# scale
]
def
wrap_trace_fn
(
process_fx
,
trace_fn
):
pm
.
register_replacement
(
pattern
,
replacement
,
inputs
,
AttentionQuantPattern
.
wrap_trace_fn
(
AttentionQuantPattern
.
fx_view_to_reshape
,
pm
.
fwd_only
),
pm_pass
)
def
wrapped
(
*
args
,
**
kwargs
):
return
process_fx
(
trace_fn
(
*
args
,
**
kwargs
))
return
wrapped
class
AttentionNvfp4QuantPattern
(
AttentionQuantPattern
):
"""
Fusion for Attention+Nvfp4Quant.
def
fx_view_to_reshape
(
gm
:
torch
.
fx
.
GraphModule
):
from
torch._inductor.fx_passes.post_grad
import
view_to_reshape
view_to_reshape
(
gm
)
return
gm
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Nvfp4Quant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def
__init__
(
self
,
layer
:
Attention
):
super
().
__init__
(
layer
,
kNvfp4Quant
)
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
):
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
self
.
layer_name
,
output_scale
=
None
,
output_block_scale
=
None
)
attn_out_view
=
RESHAPE_OP
(
at1
[
1
],
[
q
.
shape
[
0
],
self
.
num_heads
*
self
.
head_size
])
at2
=
auto_functionalized
(
self
.
QUANT_OP
,
output
=
output_quant
,
input
=
attn_out_view
,
output_scale
=
output_scale
,
input_scale
=
input_scale
)
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
at2
[
2
],
FP8_DTYPE
)
return
at2
[
1
],
output_scale_view
pm
.
register_replacement
(
pattern
,
replacement
,
inputs
,
wrap_trace_fn
(
fx_view_to_reshape
,
pm
.
fwd_only
),
pm_pass
)
def
replacement
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
):
# attention output in quant_dtype
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
[
q
.
shape
[
0
],
self
.
num_heads
,
self
.
head_size
//
2
],
0.0
,
dtype
=
self
.
quant_dtype
,
device
=
q
.
device
)
# attention output block scale
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
output_scale
,
FP8_DTYPE
)
at2
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
self
.
layer_name
,
output_scale
=
input_scale
,
output_block_scale
=
output_scale_view
)
output
=
RESHAPE_OP
(
at2
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
//
2
])
return
output
,
at2
[
2
]
inputs
=
[
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# q
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# k
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# v
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# output_attn
self
.
empty_quant
(
5
,
self
.
num_heads
*
self
.
head_size
//
2
),
# output_quant
empty_i32
(
128
,
round_up
(
self
.
num_heads
*
self
.
head_size
//
16
,
4
)),
# output_scale
empty_fp32
(
1
,
1
),
# input_scale
]
pm
.
register_replacement
(
pattern
,
replacement
,
inputs
,
AttentionQuantPattern
.
wrap_trace_fn
(
AttentionQuantPattern
.
fx_view_to_reshape
,
pm
.
fwd_only
),
pm_pass
)
class
AttnFusionPass
(
VllmInductorPass
):
...
...
@@ -138,32 +247,42 @@ class AttnFusionPass(VllmInductorPass):
support are attention kernels, which need to support fusing output quant.
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
self
.
static_fwd_ctx
=
config
.
compilation_config
.
static_forward_context
self
.
patterns
=
PatternMatcherPass
(
pass_name
=
"attn_fusion_pass"
)
for
key
,
layer
in
self
.
static_fwd_ctx
.
items
():
pattern
=
AttentionStaticQuantPattern
(
key
,
layer
.
num_heads
,
layer
.
head_size
,
current_platform
.
fp8_dtype
())
pattern
.
register_if_supported
(
self
.
patterns
,
layer
)
if
len
(
self
.
static_fwd_ctx
)
==
0
:
attn_layers
=
get_layers_from_vllm_config
(
config
,
Attention
)
for
layer_name
,
layer
in
attn_layers
.
items
():
pattern_fp8
=
AttentionFp8StaticQuantPattern
(
layer
)
pattern_fp8
.
register_if_supported
(
self
.
patterns
)
pattern_nvfp4
=
AttentionNvfp4QuantPattern
(
layer
)
pattern_nvfp4
.
register_if_supported
(
self
.
patterns
)
if
len
(
attn_layers
)
==
0
:
logger
.
warning
(
"Attention + quant fusion is enabled, but "
"CompilationConfig.static_forward_context is empty. "
"Cannot access attention layers so no fusion "
"patterns were registered."
)
"Attention + quant fusion is enabled, but no attention layers "
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered."
)
def
__call__
(
self
,
graph
:
torch
.
fx
.
graph
.
Graph
)
->
None
:
self
.
begin
()
self
.
dump_graph
(
graph
,
"before_attn_fusion"
)
count
=
self
.
patterns
.
apply
(
graph
)
# TODO: Move this to pass_manager.py after the fx graph broken issue
# has been resolved.
# see https://github.com/vllm-project/vllm/issues/23091
graph
.
eliminate_dead_code
()
logger
.
debug
(
"Fused quantization onto %s attention nodes"
,
count
)
self
.
dump_graph
(
graph
,
"after_attn_fusion"
)
self
.
end_and_log
()
def
uuid
(
self
):
return
VllmInductorPass
.
hash_source
(
self
,
AttentionStaticQuantPattern
)
return
VllmInductorPass
.
hash_source
(
self
,
AttentionQuantPattern
,
AttentionFp8StaticQuantPattern
,
AttentionNvfp4QuantPattern
)
vllm/compilation/inductor_pass.py
View file @
a99300bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
hashlib
import
inspect
import
json
...
...
@@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
import
torch
from
torch
import
fx
from
torch._subclasses.fake_tensor
import
(
FakeTensorMode
,
unset_fake_temporarily
)
from
vllm.utils
import
is_torch_equal_or_newer
...
...
@@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
def
uuid
(
self
)
->
Any
:
return
self
.
_uuid
def
enable_fake_mode
(
fn
:
Callable
[...,
Any
])
->
Callable
[...,
Any
]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@
functools
.
wraps
(
fn
)
def
fn_new
(
*
args
,
**
kwargs
)
->
Any
:
with
torch
.
_guards
.
tracing
(
None
),
unset_fake_temporarily
(),
FakeTensorMode
():
result
=
fn
(
*
args
,
**
kwargs
)
return
result
return
fn_new
vllm/compilation/monitor.py
View file @
a99300bd
...
...
@@ -43,7 +43,7 @@ cudagraph_capturing_enabled: bool = True
def
validate_cudagraph_capturing_enabled
():
# used to monitor whether a
n
cudagraph capturing is legal at runtime.
# used to monitor whether a cudagraph capturing is legal at runtime.
# should be called before any cudagraph capturing.
# if an illegal cudagraph capturing happens, raise an error.
global
cudagraph_capturing_enabled
...
...
vllm/compilation/pass_manager.py
View file @
a99300bd
...
...
@@ -8,13 +8,13 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
if
current_platform
.
is_cuda_alike
():
from
.activation_quant_fusion
import
ActivationQuantFusionPass
from
.fusion
import
FusionPass
from
.fusion_attn
import
AttnFusionPass
if
current_platform
.
is_cuda
():
from
.collective_fusion
import
AllReduceFusionPass
,
AsyncTPPass
# from .activation_quant_fusion import ActivationQuantFusionPass
from
.fix_functionalization
import
FixFunctionalizationPass
from
.inductor_pass
import
CustomGraphPass
,
InductorPass
,
get_pass_context
from
.noop_elimination
import
NoOpEliminationPass
...
...
vllm/compilation/sequence_parallelism.py
View file @
a99300bd
...
...
@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import (
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
...
...
@@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass):
performance.
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
...
...
vllm/config/__init__.py
View file @
a99300bd
...
...
@@ -36,7 +36,8 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
PrefixCachingHashAlgo
)
from
vllm.config.compilation
import
(
CompilationConfig
,
CompilationLevel
,
CUDAGraphMode
,
PassConfig
)
from
vllm.config.parallel
import
DistributedExecutorBackend
,
ParallelConfig
from
vllm.config.parallel
import
(
DistributedExecutorBackend
,
EPLBConfig
,
ParallelConfig
)
from
vllm.config.scheduler
import
SchedulerConfig
,
SchedulerPolicy
from
vllm.config.utils
import
ConfigType
,
config
from
vllm.logger
import
init_logger
...
...
@@ -199,7 +200,17 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
yield
a
,
b
a
=
b
cls_node
=
ast
.
parse
(
textwrap
.
dedent
(
inspect
.
getsource
(
cls
))).
body
[
0
]
try
:
cls_node
=
ast
.
parse
(
textwrap
.
dedent
(
inspect
.
getsource
(
cls
))).
body
[
0
]
except
(
OSError
,
KeyError
,
TypeError
):
# HACK: Python 3.13+ workaround - set missing __firstlineno__
# Workaround can be removed after we upgrade to pydantic==2.12.0
with
open
(
inspect
.
getfile
(
cls
))
as
f
:
for
i
,
line
in
enumerate
(
f
):
if
f
"class
{
cls
.
__name__
}
"
in
line
and
":"
in
line
:
cls
.
__firstlineno__
=
i
+
1
break
cls_node
=
ast
.
parse
(
textwrap
.
dedent
(
inspect
.
getsource
(
cls
))).
body
[
0
]
if
not
isinstance
(
cls_node
,
ast
.
ClassDef
):
raise
TypeError
(
"Given object was not a class."
)
...
...
@@ -254,8 +265,14 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode
=
Literal
[
"auto"
,
"cpm"
,
"slow"
,
"mistral"
,
"custom"
]
ModelDType
=
Literal
[
"auto"
,
"half"
,
"float16"
,
"bfloat16"
,
"float"
,
"float32"
]
LogprobsMode
=
Literal
[
"raw_logprobs"
,
"raw_logits"
,
"processed_logprobs"
,
"processed_logits"
]
MMEncoderTPMode
=
Literal
[
"weights"
,
"data"
]
class
LogprobsMode
(
enum
.
Enum
):
RAW_LOGITS
=
"raw_logits"
RAW_LOGPROBS
=
"raw_logprobs"
PROCESSED_LOGITS
=
"processed_logits"
PROCESSED_LOGPROBS
=
"processed_logprobs"
@
config
...
...
@@ -359,12 +376,13 @@ class ModelConfig:
specified in `SamplingParams`. The default value comes the default for the
OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
vocab_size) logprobs are allowed to be returned and it may cause OOM."""
logprobs_mode
:
LogprobsMode
=
"raw_l
ogprobs
"
logprobs_mode
:
LogprobsMode
=
L
ogprobs
Mode
.
RAW_LOGPROBS
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported mode:
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
Raw means the values before applying logit processors, like bad words.
Processed means the values after applying such processors.
Raw means the values before applying any logit processors, like bad words.
Processed means the values after applying all processors, including
temperature and top_k/top_p.
"""
disable_sliding_window
:
bool
=
False
"""Whether to disable sliding window. If True, we will disable the sliding
...
...
@@ -427,7 +445,7 @@ class ModelConfig:
from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
"""
mm_processor_cache_gb
:
in
t
=
4
mm_processor_cache_gb
:
floa
t
=
4
"""The size (in GiB) of the multi-modal processor cache, which is used to
avoid re-processing past multi-modal inputs.
...
...
@@ -436,6 +454,19 @@ class ModelConfig:
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
Set to `0` to disable this cache completely (not recommended)."""
mm_encoder_tp_mode
:
MMEncoderTPMode
=
"weights"
"""Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
override_neuron_config
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
...
...
@@ -478,6 +509,8 @@ class ModelConfig:
logits_processors
:
Optional
[
list
[
Union
[
str
,
type
[
LogitsProcessor
]]]]
=
None
"""One or more logits processors' fully-qualified class names or class
definitions"""
io_processor_plugin
:
Optional
[
str
]
=
None
"""IOProcessor plugin name to load at model startup"""
enable_chunked_prefill
:
Optional
[
bool
]
=
None
"""If True, prefill requests can be chunked based
...
...
@@ -854,22 +887,25 @@ class ModelConfig:
def
_init_multimodal_config
(
self
)
->
Optional
[
"MultiModalConfig"
]:
if
self
.
_model_info
.
supports_multimodal
:
if
(
self
.
mm_encoder_tp_mode
==
"data"
and
not
self
.
_model_info
.
supports_multimodal_encoder_tp_data
):
logger
.
warning_once
(
"This model does not support `--mm-encoder-tp-mode data`. "
"Falling back to `--mm-encoder-tp-mode weights`."
)
self
.
mm_encoder_tp_mode
=
"weights"
return
MultiModalConfig
(
limit_per_prompt
=
self
.
limit_mm_per_prompt
,
media_io_kwargs
=
self
.
media_io_kwargs
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
mm_processor_cache_gb
=
self
.
mm_processor_cache_gb
,
mm_encoder_tp_mode
=
self
.
mm_encoder_tp_mode
,
interleave_mm_strings
=
self
.
interleave_mm_strings
,
skip_mm_profiling
=
self
.
skip_mm_profiling
)
skip_mm_profiling
=
self
.
skip_mm_profiling
,
)
return
None
def
set_mm_processor_cache_gb
(
self
,
value
:
int
)
->
None
:
mm_config
=
self
.
get_multimodal_config
()
self
.
mm_processor_cache_gb
=
value
mm_config
.
mm_processor_cache_gb
=
value
def
_get_encoder_config
(
self
):
return
get_sentence_transformer_tokenizer_config
(
self
.
model
,
self
.
revision
)
...
...
@@ -1099,10 +1135,22 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
me_quant
.
QUANTIZATION_METHODS
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed-tensors"
,
"experts_int8"
,
"quark"
,
"modelopt_fp4"
,
"bitblas"
,
"gptq_bitblas"
,
"inc"
,
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
"fp8"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed-tensors"
,
"experts_int8"
,
"quark"
,
"modelopt_fp4"
,
"bitblas"
,
"gptq_bitblas"
,
"inc"
,
"petit_nvfp4"
,
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
cast
(
me_quant
.
QuantizationMethods
,
...
...
@@ -1125,7 +1173,6 @@ class ModelConfig:
# `override_quantization_method` method) must be checked in order
# of preference (this is particularly important for GPTQ).
overrides
=
[
"marlin"
,
"bitblas"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
...
...
@@ -1136,6 +1183,7 @@ class ModelConfig:
"slimquant_w4a8_marlin"
"modelopt"
,
"modelopt_fp4"
,
"petit_nvfp4"
,
]
quantization_methods
=
[
q
for
q
in
supported_quantization
if
q
not
in
overrides
...
...
@@ -1470,7 +1518,8 @@ class ModelConfig:
from
vllm.distributed.utils
import
get_pp_indices
if
(
self
.
hf_text_config
.
model_type
==
"deepseek_mtp"
or
self
.
hf_config
.
model_type
==
"mimo_mtp"
or
self
.
hf_config
.
model_type
==
"glm4_moe_mtp"
):
or
self
.
hf_config
.
model_type
==
"glm4_moe_mtp"
or
self
.
hf_config
.
model_type
==
"ernie_mtp"
):
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
0
)
else
:
...
...
@@ -1670,29 +1719,8 @@ class ModelConfig:
return
self
.
multimodal_config
is
not
None
@
property
def
processor_return_mm_hashes
(
self
)
->
bool
:
"""Whether the multi-modal processor should output hashes."""
mm_config
=
self
.
multimodal_config
if
mm_config
is
None
:
return
False
return
mm_config
.
mm_processor_cache_gb
>
0
@
property
def
enable_mm_processor_cache
(
self
)
->
bool
:
"""Whether the multi-modal processor cache should be enabled."""
mm_config
=
self
.
multimodal_config
if
mm_config
is
None
:
return
False
return
mm_config
.
mm_processor_cache_gb
>
0
def
get_mm_input_cache_gb
(
self
)
->
int
:
mm_config
=
self
.
multimodal_config
if
mm_config
is
None
:
return
0
return
envs
.
VLLM_MM_INPUT_CACHE_GIB
def
is_multimodal_raw_input_only_model
(
self
)
->
bool
:
return
self
.
_model_info
.
supports_multimodal_raw_input_only
@
property
def
is_cross_encoder
(
self
)
->
bool
:
...
...
@@ -1703,10 +1731,6 @@ class ModelConfig:
def
is_pp_supported
(
self
)
->
bool
:
return
self
.
_model_info
.
supports_pp
@
property
def
is_multimodal_raw_input_supported
(
self
)
->
bool
:
return
self
.
_model_info
.
supports_multimodal_raw_input
@
property
def
is_attention_free
(
self
)
->
bool
:
return
self
.
_model_info
.
is_attention_free
...
...
@@ -1917,7 +1941,8 @@ class DeviceConfig:
SpeculativeMethod
=
Literal
[
"ngram"
,
"eagle"
,
"eagle3"
,
"medusa"
,
"mlp_speculator"
,
"draft_model"
,
"deepseek_mtp"
]
"mlp_speculator"
,
"draft_model"
,
"deepseek_mtp"
,
"ernie_mtp"
]
@
config
...
...
@@ -2050,6 +2075,16 @@ class SpeculativeConfig:
"architectures"
:
[
"Glm4MoeMTPModel"
]
})
if
hf_config
.
model_type
==
"ernie4_5_moe"
:
hf_config
.
model_type
=
"ernie_mtp"
if
hf_config
.
model_type
==
"ernie_mtp"
:
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
None
)
hf_config
.
update
({
"n_predict"
:
n_predict
,
"architectures"
:
[
"ErnieMTPModel"
]
})
return
hf_config
return
hf_config
def
__post_init__
(
self
):
...
...
@@ -2068,8 +2103,8 @@ class SpeculativeConfig:
if
self
.
target_model_config
and
\
(
self
.
target_model_config
.
hf_text_config
.
model_type
\
==
"deepseek_v3"
or
self
.
target_model_config
.
hf_text_config
.
model_type
\
==
"mimo"
):
self
.
target_model_config
.
hf_text_config
.
model_type
in
(
"mimo"
,
"ernie4_5_moe"
)
):
# use the draft model from the same model:
self
.
model
=
self
.
target_model_config
.
model
elif
self
.
method
in
(
"ngram"
,
"[ngram]"
):
...
...
@@ -2167,6 +2202,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes "
\
"to support multiple layers."
)
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
==
"ernie_mtp"
):
self
.
method
=
"ernie_mtp"
if
self
.
num_speculative_tokens
>
1
:
logger
.
warning
(
"All Ernie MTP models only have "
\
"one layer. Might need some code changes "
\
"to support multiple layers."
)
else
:
self
.
method
=
"draft_model"
raise
NotImplementedError
(
...
...
@@ -2386,7 +2430,7 @@ class SpeculativeConfig:
return
self
.
num_speculative_tokens
def
use_eagle
(
self
)
->
bool
:
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"deepseek_mtp"
)
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"deepseek_mtp"
,
"ernie_mtp"
)
def
__repr__
(
self
)
->
str
:
method
=
self
.
method
...
...
@@ -2422,8 +2466,8 @@ class LoRAConfig:
lora_dtype
:
Union
[
torch
.
dtype
,
LoRADType
]
=
"auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
lora_extra_vocab_size
:
int
=
256
"""Maximum size of extra vocabulary that can be present in a
LoRA adapter
(added to the base model vocabulary)
."""
"""
(Deprecated)
Maximum size of extra vocabulary that can be present in a
LoRA adapter. Will be removed in v0.12.0
."""
lora_vocab_padding_size
:
ClassVar
[
int
]
=
current_platform
\
.
get_lora_vocab_padding_size
()
...
...
@@ -2465,6 +2509,12 @@ class LoRAConfig:
return
hash_str
def
__post_init__
(
self
):
# Deprecation warning for lora_extra_vocab_size
logger
.
warning
(
"`lora_extra_vocab_size` is deprecated and will be removed "
"in v0.12.0. Additional vocabulary support for "
"LoRA adapters is being phased out."
)
# Setting the maximum rank to 512 should be able to satisfy the vast
# majority of applications.
possible_max_ranks
=
(
8
,
16
,
32
,
64
,
128
,
256
,
320
,
512
)
...
...
@@ -2529,7 +2579,7 @@ class MultiModalConfig:
`{"num_crops": 4}`.
"""
mm_processor_cache_gb
:
in
t
=
4
mm_processor_cache_gb
:
floa
t
=
4
"""
The size (in GiB) of the multi-modal processor cache, which is used to
...
...
@@ -2540,6 +2590,22 @@ class MultiModalConfig:
Set to `0` to disable this cache completely (not recommended).
"""
mm_encoder_tp_mode
:
MMEncoderTPMode
=
"weights"
"""
Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP.
"""
interleave_mm_strings
:
bool
=
False
"""
Enable fully interleaved support for multimodal prompts.
...
...
@@ -2547,7 +2613,7 @@ class MultiModalConfig:
skip_mm_profiling
:
bool
=
False
"""
When enabled, skips multimodal memory profiling and only profiles with
When enabled, skips multimodal memory profiling and only profiles with
language backbone model during engine initialization.
This reduces engine startup time but shifts the responsibility to users for
...
...
@@ -2610,24 +2676,24 @@ class PoolerConfig:
## for embeddings models
normalize
:
Optional
[
bool
]
=
None
"""
Whether to normalize the embeddings outputs.
Whether to normalize the embeddings outputs.
"""
dimensions
:
Optional
[
int
]
=
None
"""
Reduce the dimensions of embeddings if model
Reduce the dimensions of embeddings if model
support matryoshka representation.
"""
## for classification models
activation
:
Optional
[
bool
]
=
None
"""
Whether to apply activation function to the classification outputs.
Whether to apply activation function to the classification outputs.
"""
## for reward models
softmax
:
Optional
[
bool
]
=
None
"""
Whether to apply softmax to the reward outputs.
Whether to apply softmax to the reward outputs.
"""
step_tag_id
:
Optional
[
int
]
=
None
"""
...
...
@@ -2653,9 +2719,9 @@ class PoolerConfig:
max_embed_len
:
Optional
[
int
]
=
None
"""
Maximum input length allowed for embedding generation. When set, allows
Maximum input length allowed for embedding generation. When set, allows
inputs longer than max_embed_len to be accepted for embedding models.
This parameter enables accepting long inputs without requiring
This parameter enables accepting long inputs without requiring
VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
max_embed_len, it will be handled according to the original max_model_len
validation logic. Defaults to None (i.e. set to max_model_len).
...
...
@@ -3009,7 +3075,8 @@ def get_served_model_name(model: str,
return
served_model_name
GuidedDecodingBackend
=
Literal
[
"auto"
,
"xgrammar"
,
"guidance"
,
"outlines"
]
GuidedDecodingBackend
=
Literal
[
"auto"
,
"xgrammar"
,
"guidance"
,
"outlines"
,
"lm-format-enforcer"
]
@
config
...
...
@@ -3572,7 +3639,7 @@ class VllmConfig:
if
self
.
compilation_config
.
pass_config
.
enable_sequence_parallelism
:
self
.
compilation_config
.
custom_ops
.
append
(
"+rms_norm"
)
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
()
or
current_platform
.
is_xpu
()
:
# if cudagraph_mode is not explicitly set by users, set default
# value
if
self
.
compilation_config
.
cudagraph_mode
is
None
:
...
...
vllm/config/cache.py
View file @
a99300bd
...
...
@@ -115,8 +115,8 @@ class CacheConfig:
In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
some layers can skip tokens corresponding to prefill. This flag enables
attention metadata for eligible layers to be overriden with metadata
necessary for implement
at
ing this optimization in some models (e.g. Gemma3n)
attention metadata for eligible layers to be overrid
d
en with metadata
necessary for implementing this optimization in some models (e.g. Gemma3n)
"""
def
compute_hash
(
self
)
->
str
:
...
...
@@ -145,12 +145,19 @@ class CacheConfig:
self
.
_verify_cache_dtype
()
self
.
_verify_prefix_caching
()
self
.
_verify_kv_sharing_fast_prefill
()
def
metrics_info
(
self
):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
return
{
key
:
str
(
value
)
for
key
,
value
in
self
.
__dict__
.
items
()}
def
_verify_kv_sharing_fast_prefill
(
self
)
->
None
:
if
self
.
kv_sharing_fast_prefill
and
not
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
"Fast prefill optimization for KV sharing is not supported "
"in V0 currently."
)
@
model_validator
(
mode
=
'after'
)
def
_verify_args
(
self
)
->
Self
:
if
self
.
cpu_offload_gb
<
0
:
...
...
@@ -162,11 +169,6 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got "
f
"
{
self
.
gpu_memory_utilization
}
."
)
if
self
.
kv_sharing_fast_prefill
:
logger
.
warning_once
(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)"
)
return
self
def
_verify_cache_dtype
(
self
)
->
None
:
...
...
vllm/config/compilation.py
View file @
a99300bd
...
...
@@ -225,7 +225,8 @@ class CompilationConfig:
# CudaGraph compilation
cudagraph_mode
:
Optional
[
CUDAGraphMode
]
=
None
"""
The mode of the cudagraph.
The mode of the cudagraph:
- NONE, no cudagraph capture.
- PIECEWISE. (v1 default)
- FULL.
...
...
@@ -336,6 +337,9 @@ class CompilationConfig:
"vllm.unified_attention"
,
"vllm.unified_attention_with_output"
,
"vllm.mamba_mixer2"
,
"vllm.mamba_mixer"
,
"vllm.short_conv"
,
"vllm.linear_attention"
,
]
def
compute_hash
(
self
)
->
str
:
...
...
@@ -382,13 +386,10 @@ class CompilationConfig:
if
pass_config_exclude
:
exclude
[
"pass_config"
]
=
pass_config_exclude
# The cast to string is necessary because Pydantic is mocked in docs
# builds and sphinx-argparse doesn't know the return type of decode()
return
str
(
TypeAdapter
(
CompilationConfig
).
dump_json
(
self
,
exclude
=
exclude
,
# type: ignore[arg-type]
exclude_unset
=
True
).
decode
())
return
TypeAdapter
(
CompilationConfig
).
dump_json
(
self
,
exclude
=
exclude
,
# type: ignore[arg-type]
exclude_unset
=
True
).
decode
()
__str__
=
__repr__
...
...
vllm/config/parallel.py
View file @
a99300bd
...
...
@@ -15,7 +15,7 @@ import vllm.envs as envs
from
vllm.config.utils
import
config
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cuda_device_count_stateless
,
get_open_port
from
vllm.utils
import
cuda_device_count_stateless
,
get_open_port
s_list
if
TYPE_CHECKING
:
from
ray.runtime_env
import
RuntimeEnv
...
...
@@ -32,6 +32,31 @@ logger = init_logger(__name__)
DistributedExecutorBackend
=
Literal
[
"ray"
,
"mp"
,
"uni"
,
"external_launcher"
]
@
config
@
dataclass
class
EPLBConfig
:
"""Configuration for Expert Parallel Load Balancing (EP)."""
window_size
:
int
=
1000
"""Window size for expert load recording."""
step_interval
:
int
=
3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `lb_window_size` steps will be used for rearranging experts.
"""
num_redundant_experts
:
int
=
0
"""Number of redundant experts to use for expert parallelism."""
log_balancedness
:
bool
=
False
"""
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
@
config
@
dataclass
class
ParallelConfig
:
...
...
@@ -75,22 +100,24 @@ class ParallelConfig:
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb
:
bool
=
False
"""Enable expert parallelism load balancing for MoE layers."""
num_redundant_experts
:
int
=
0
"""Number of redundant experts to use for expert parallelism."""
eplb_window_size
:
int
=
1000
"""Window size for expert load recording."""
eplb_step_interval
:
int
=
3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `eplb_window_size` steps will be used for rearranging experts.
"""
eplb_log_balancedness
:
bool
=
False
"""
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
eplb_config
:
EPLBConfig
=
field
(
default_factory
=
EPLBConfig
)
"""Expert parallelism configuration."""
num_redundant_experts
:
Optional
[
int
]
=
None
"""`num_redundant_experts` is deprecated and has been replaced with
`eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
Please use `eplb_config.num_redundant_experts` instead."""
eplb_window_size
:
Optional
[
int
]
=
None
"""`eplb_window_size` is deprecated and has been replaced with
`eplb_config.window_size`. This will be removed in v0.12.0.
Please use `eplb_config.window_size` instead."""
eplb_step_interval
:
Optional
[
int
]
=
None
"""`eplb_step_interval` is deprecated and has been replaced with
`eplb_config.step_interval`. This will be removed in v0.12.0.
Please use `eplb_config.step_interval` instead."""
eplb_log_balancedness
:
Optional
[
bool
]
=
None
"""`eplb_log_balancedness` is deprecated and has been replaced with
`eplb_config.log_balancedness`. This will be removed in v0.12.0.
Please use `eplb_config.log_balancedness` instead."""
max_parallel_loading_workers
:
Optional
[
int
]
=
None
"""Maximum number of parallel loading workers when loading model
...
...
@@ -109,7 +136,8 @@ class ParallelConfig:
placement_group
:
Optional
[
PlacementGroup
]
=
None
"""ray distributed model workers placement group."""
distributed_executor_backend
:
Optional
[
Union
[
DistributedExecutorBackend
,
distributed_executor_backend
:
Optional
[
Union
[
str
,
DistributedExecutorBackend
,
type
[
ExecutorBase
]]]
=
None
"""Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
...
...
@@ -137,9 +165,10 @@ class ParallelConfig:
rank
:
int
=
0
"""Global rank in distributed setup."""
enable_multimodal_encoder_data_parallel
:
bool
=
False
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
_data_parallel_master_port_list
:
list
[
int
]
=
field
(
default_factory
=
list
)
"""List of open port auto-queried for data parallel messaging.
Set to be private as it's not intended to be configured by users.
"""
@
property
def
world_size_across_dp
(
self
)
->
int
:
...
...
@@ -153,11 +182,15 @@ class ParallelConfig:
processes that is related to data parallelism,
e.g. both in the worker and in the engine, which
can live in different processes. To avoid port conflicts, we
increment the port number
each time we need to
initialize a
new process group related to data parallelism.
pop a new port from the prepared port list
each time we need to
initialize a
new process group related to data parallelism.
"""
answer
=
self
.
data_parallel_master_port
self
.
data_parallel_master_port
+=
1
if
self
.
_data_parallel_master_port_list
:
answer
=
self
.
_data_parallel_master_port_list
.
pop
()
else
:
answer
=
self
.
data_parallel_master_port
self
.
data_parallel_master_port
+=
1
return
answer
def
stateless_init_dp_group
(
self
)
->
ProcessGroup
:
...
...
@@ -241,6 +274,38 @@ class ParallelConfig:
return
hashlib
.
sha256
(
str
(
factors
).
encode
()).
hexdigest
()
def
__post_init__
(
self
)
->
None
:
# Forward deprecated fields to their new location
if
self
.
num_redundant_experts
is
not
None
:
self
.
eplb_config
.
num_redundant_experts
=
(
self
.
num_redundant_experts
)
logger
.
warning_once
(
"num_redundant_experts is deprecated and has been replaced "
"with eplb_config.num_redundant_experts. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect."
)
if
self
.
eplb_window_size
is
not
None
:
self
.
eplb_config
.
window_size
=
self
.
eplb_window_size
logger
.
warning_once
(
"eplb_window_size is deprecated and has been replaced "
"with eplb_config.window_size. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect."
)
if
self
.
eplb_step_interval
is
not
None
:
self
.
eplb_config
.
step_interval
=
self
.
eplb_step_interval
logger
.
warning_once
(
"eplb_step_interval is deprecated and has been replaced "
"with eplb_config.step_interval. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect."
)
if
self
.
eplb_log_balancedness
is
not
None
:
self
.
eplb_config
.
log_balancedness
=
self
.
eplb_log_balancedness
logger
.
warning_once
(
"eplb_log_balancedness is deprecated and has been replaced "
"with eplb_config.log_balancedness. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect."
)
# Continue with the rest of the initialization
self
.
world_size
=
self
.
pipeline_parallel_size
*
\
self
.
tensor_parallel_size
...
...
@@ -251,7 +316,10 @@ class ParallelConfig:
if
self
.
data_parallel_size
>
1
or
self
.
data_parallel_size_local
==
0
:
# Data parallel was specified in the engine args.
self
.
data_parallel_master_port
=
get_open_port
()
if
not
self
.
_data_parallel_master_port_list
:
self
.
_data_parallel_master_port_list
=
get_open_ports_list
(
5
)
self
.
data_parallel_master_port
=
\
self
.
_data_parallel_master_port_list
.
pop
()
if
not
(
0
<=
self
.
data_parallel_rank
<
self
.
data_parallel_size
):
raise
ValueError
(
...
...
@@ -279,10 +347,10 @@ class ParallelConfig:
raise
ValueError
(
"Expert parallelism load balancing is only supported on "
"CUDA devices now."
)
if
self
.
num_redundant_experts
<
0
:
if
self
.
eplb_config
.
num_redundant_experts
<
0
:
raise
ValueError
(
"num_redundant_experts must be non-negative, but got "
f
"
{
self
.
num_redundant_experts
}
."
)
f
"
{
self
.
eplb_config
.
num_redundant_experts
}
."
)
if
not
self
.
enable_expert_parallel
:
raise
ValueError
(
"enable_expert_parallel must be True to use EPLB."
)
...
...
@@ -293,10 +361,10 @@ class ParallelConfig:
f
"TP=
{
self
.
tensor_parallel_size
}
,DP=
{
self
.
data_parallel_size
}
."
)
else
:
if
self
.
num_redundant_experts
!=
0
:
if
self
.
eplb_config
.
num_redundant_experts
!=
0
:
raise
ValueError
(
"num_redundant_experts should be used with EPLB."
f
"
{
self
.
num_redundant_experts
}
."
)
f
"
{
self
.
eplb_config
.
num_redundant_experts
}
."
)
if
self
.
distributed_executor_backend
is
None
and
self
.
world_size
>
1
:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
...
...
@@ -342,23 +410,22 @@ class ParallelConfig:
def
use_ray
(
self
)
->
bool
:
return
self
.
distributed_executor_backend
==
"ray"
or
(
isinstance
(
self
.
distributed_executor_backend
,
type
)
and
self
.
distributed_executor_backend
.
uses_ray
)
and
getattr
(
self
.
distributed_executor_backend
,
"
uses_ray
"
,
False
)
)
@
model_validator
(
mode
=
'after'
)
def
_verify_args
(
self
)
->
Self
:
# Lazy import to avoid circular import
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.platforms
import
current_platform
if
self
.
distributed_executor_backend
not
in
(
"ray"
,
"mp"
,
"uni"
,
"external_launcher"
,
None
)
and
not
(
isinstance
(
if
self
.
distributed_executor_backend
is
not
None
and
not
isinstance
(
self
.
distributed_executor_backend
,
str
)
and
not
(
isinstance
(
self
.
distributed_executor_backend
,
type
)
and
issubclass
(
self
.
distributed_executor_backend
,
ExecutorBase
)):
raise
ValueError
(
"Unrecognized distributed executor backend "
f
"
{
self
.
distributed_executor_backend
}
. Supported "
"values are 'ray', 'mp' 'uni', 'external_launcher'
or
"
" custom ExecutorBase subclass."
)
"values are 'ray', 'mp' 'uni', 'external_launcher'
,
"
" custom ExecutorBase subclass
or its import path
."
)
if
self
.
use_ray
:
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
...
...
vllm/core/block/naive_block.py
View file @
a99300bd
...
...
@@ -207,7 +207,7 @@ class NaiveBlockAllocator(BlockAllocator):
Args:
absolute_id (int): The absolute block id for the block
in whole allocator.
in whole allocator.
Returns:
int: The zero-offset block id on certain device.
...
...
vllm/core/block/prefix_caching_block.py
View file @
a99300bd
...
...
@@ -61,7 +61,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Args:
num_blocks (int): The total number of blocks to manage.
block_size (int): The size of each block in tokens.
block_ids(Optional[Iterable[int]], optional): An optional iterable of
block_ids
(Optional[Iterable[int]], optional): An optional iterable of
block IDs. If not provided, block IDs will be assigned sequentially
from 0 to num_blocks - 1.
"""
...
...
Prev
1
…
16
17
18
19
20
21
22
23
24
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment