Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
...@@ -5,6 +5,7 @@ from collections.abc import Iterable ...@@ -5,6 +5,7 @@ from collections.abc import Iterable
import torch.fx import torch.fx
from torch import SymInt from torch import SymInt
from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -116,12 +117,7 @@ class NoOpEliminationPass(VllmInductorPass): ...@@ -116,12 +117,7 @@ class NoOpEliminationPass(VllmInductorPass):
2. The dimensions both correspond to the same SymInt 2. The dimensions both correspond to the same SymInt
""" """
# Case 1 # Case 1
if isinstance(i_dim, int) and isinstance(dim, int): return statically_known_true(dim == i_dim)
return dim == i_dim
# Case 2
if isinstance(i_dim, SymInt) and isinstance(dim, SymInt):
return dim == i_dim
return False
def all_dims_equivalent( def all_dims_equivalent(
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt] self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
......
...@@ -5,6 +5,7 @@ import functools ...@@ -5,6 +5,7 @@ import functools
from torch import fx as fx from torch import fx as fx
from vllm import envs from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var ...@@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var
from .post_cleanup import PostCleanupPass from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import RMSNormQuantFusionPass from .fusion import RMSNormQuantFusionPass
...@@ -24,7 +31,11 @@ if current_platform.is_cuda(): ...@@ -24,7 +31,11 @@ if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass from .collective_fusion import AllReduceFusionPass, AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .inductor_pass import (
CustomGraphPass,
InductorPass,
get_pass_context,
)
from .noop_elimination import NoOpEliminationPass from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -70,13 +81,13 @@ class PostGradPassManager(CustomGraphPass): ...@@ -70,13 +81,13 @@ class PostGradPassManager(CustomGraphPass):
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph):
VllmInductorPass.dump_prefix = 0 # reset dump index VllmInductorPass.dump_prefix = 0 # reset dump index
shape = get_pass_context().runtime_shape compile_range = get_pass_context().compile_range
for pass_ in self.passes: for pass_ in self.passes:
if pass_.is_applicable(shape): if pass_.is_applicable_for_range(compile_range):
pass_(graph) pass_(graph)
VllmInductorPass.dump_prefix += 1 VllmInductorPass.dump_prefix += 1
else: else:
logger.debug("Skipping %s with shape %s", pass_, shape) logger.debug("Skipping %s with compile range %s", pass_, compile_range)
# post-cleanup goes before fix_functionalization # post-cleanup goes before fix_functionalization
# because it requires a functional graph # because it requires a functional graph
...@@ -105,8 +116,12 @@ class PostGradPassManager(CustomGraphPass): ...@@ -105,8 +116,12 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.fuse_norm_quant: if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)] self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_act_quant: if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)] self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_attn_quant: if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)] self.passes += [AttnFusionPass(config)]
...@@ -133,4 +148,8 @@ class PostGradPassManager(CustomGraphPass): ...@@ -133,4 +148,8 @@ class PostGradPassManager(CustomGraphPass):
state["passes"].append(pass_.uuid()) state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid()) state["passes"].append(self.fix_functionalization.uuid())
# Include the compile range in the uuid to ensure that inductor
# recompiles the graph for the new dynamic compile range.
state["compile_range"] = str(get_pass_context().compile_range)
return InductorPass.hash_dict(state) return InductorPass.hash_dict(state)
...@@ -7,18 +7,18 @@ from typing import Any ...@@ -7,18 +7,18 @@ from typing import Any
import torch.fx as fx import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclasses.dataclass @dataclasses.dataclass
class ConcreteSizeEntry: class RangeEntry:
runtime_shape: int compile_range: Range
compiled: bool = False compiled: bool = False
runnable: Callable = None # type: ignore runnable: Callable = None # type: ignore
...@@ -31,7 +31,6 @@ class PiecewiseBackend: ...@@ -31,7 +31,6 @@ class PiecewiseBackend:
piecewise_compile_index: int, piecewise_compile_index: int,
total_piecewise_compiles: int, total_piecewise_compiles: int,
sym_shape_indices: list[int], sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend, vllm_backend: VllmBackend,
): ):
""" """
...@@ -54,68 +53,131 @@ class PiecewiseBackend: ...@@ -54,68 +53,131 @@ class PiecewiseBackend:
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.is_full_graph = total_piecewise_compiles == 1 self.is_full_graph = total_piecewise_compiles == 1
# TODO: we need to generalize encoder compilation to other models
self.is_encoder_compilation = vllm_backend.prefix in [
"Qwen2_5_VisionPatchEmbed",
"Qwen2_5_VisionPatchMerger",
"Qwen2_5_VisionBlock",
]
self.compile_ranges = self.compilation_config.get_compile_ranges()
if self.is_encoder_compilation:
# For encoder compilation we use the max int32 value
# to set the upper bound of the compile ranges
max_int32 = 2**31 - 1
last_compile_range = self.compile_ranges[-1]
assert (
last_compile_range.end
== vllm_config.scheduler_config.max_num_batched_tokens
)
self.compile_ranges[-1] = Range(
start=last_compile_range.start, end=max_int32
)
self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
logger.debug_once(log_string)
self.first_run_finished = False
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa self.compile_sizes = self.compilation_config.compile_sizes
log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
logger.debug_once(log_string)
self.sym_shape_indices = sym_shape_indices self.sym_shape_indices = sym_shape_indices
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" # the entries for ranges that we need to either
self.range_entries: dict[Range, RangeEntry] = {}
# the entries for different shapes that we need to compile # to_be_compiled_ranges tracks the remaining ranges to compile,
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it # and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
# We only keep compilation management inside this class directly. # We only keep compilation management inside this class directly.
for shape in self.compile_sizes: for size in self.compile_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry( range = Range(start=size, end=size)
runtime_shape=shape, if range not in self.compile_ranges:
runnable=self.compiled_graph_for_general_shape, self.range_entries[range] = RangeEntry(
compile_range=range,
)
self.to_be_compiled_ranges.add(range)
for range in self.compile_ranges:
self.range_entries[range] = RangeEntry(
compile_range=range,
) )
def check_for_ending_compilation(self): def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes: if self.is_last_graph and not self.to_be_compiled_ranges:
# no specific sizes to compile # no specific sizes to compile
# save the hash of the inductor graph for the next run # save the hash of the inductor graph for the next run
self.vllm_backend.compiler_manager.save_to_file() self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config) end_monitoring_torch_compile(self.vllm_config)
def __call__(self, *args) -> Any: def _fakify_args(self, args: list[Any]) -> list[Any]:
if not self.first_run_finished: # We need to pass fake example_inputs, otherwise torch.compile
self.first_run_finished = True # will fakify the example_inputs potentially causing some non dynamic
self.check_for_ending_compilation() # dimension to be be duck shaped to other existing shapes that have hints
return self.compiled_graph_for_general_shape(*args) # matching their values.
# This is problem because it can lead to unintended specializations!
runtime_shape = args[self.sym_shape_indices[0]] # if the new wrongly dynamic dim is specialized
# it will force specializing the whole shape
if runtime_shape not in self.concrete_size_entries: # torch.compile probably should not accept
# we don't need to do anything for this shape # non fake tensors as example inputs!
return self.compiled_graph_for_general_shape(*args) # See issue https://github.com/vllm-project/vllm/issues/27899
fake_example_inputs = []
for node in self.graph.graph.nodes:
# All place holders come first
if node.op == "placeholder":
fake_example_inputs.append(node.meta["example_value"])
else:
break
assert len(fake_example_inputs) == len(args)
return fake_example_inputs
def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any:
if not range_entry.compiled:
range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
entry = self.concrete_size_entries[runtime_shape]
if not entry.compiled:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments # args are real arguments
entry.runnable = self.vllm_backend.compiler_manager.compile( # fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify.
args = (
self._fakify_args(args)
if not range_entry.compile_range.is_single_size()
else args
)
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph, self.graph,
args, args,
self.vllm_backend.inductor_config, self.vllm_backend.inductor_config,
self.compilation_config, self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index, graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles, num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape,
) )
# finished compilations for all required shapes self.check_for_ending_compilation()
if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation() def _find_range_for_shape(self, runtime_shape: int) -> Range | None:
# First we try to find the range entry for the concrete compile size
# If not found, we search for the range entry
# that contains the runtime shape.
if runtime_shape in self.compile_sizes:
return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
else:
for range in self.compile_ranges:
if runtime_shape in range:
return self.range_entries[range]
return None
def __call__(self, *args) -> Any:
runtime_shape = args[self.sym_shape_indices[0]]
range_entry = self._find_range_for_shape(runtime_shape)
assert range_entry is not None, (
f"Shape out of considered range: {runtime_shape} "
"[1, max_num_batched_tokens]"
)
return entry.runnable(*args) self._maybe_compile_for_range_entry(range_entry, args)
return range_entry.runnable(*args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fusion import empty_bf16
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherSiluAndMul
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = (
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
)
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
class AiterRMSFp8GroupQuantPattern:
"""
This pattern fuses aiter rms_norm & group fp8 quant custom
ops into an aiter rms_norm_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_GROUP_QUANT_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
return at[0], at[1]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class AiterFusedAddRMSFp8GroupQuantPattern:
"""
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
into a aiter rms_norm_with_add_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_ADD_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
)
at2 = self.quant_op(at1[0], 128)
# result, scale, residual
return at2[0], at2[1], at1[1]
def replacement(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_ADD_GROUP_QUANT_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
# result, scale, residual
return at[0], at[1], at[2]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + dynamic group fp8 quant
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
self.patterns
)
AiterFusedAddRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, quant_op
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any:
fusion_patterns = [
AiterRMSFp8GroupQuantPattern,
AiterFusedAddRMSFp8GroupQuantPattern,
]
return self.hash_source(self, *fusion_patterns)
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
"""
This pattern fuses aiter silu_and_mul & group fp8 quant custom
ops into an aiter silu_and_mul_group_fp8_quant op.
"""
def __init__(self, quant_op: OpOverload):
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
):
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
):
at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1]
inputs = [
self.silu_and_mul_matcher.inputs()[0],
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self):
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)
...@@ -9,6 +9,7 @@ import torch.fx as fx ...@@ -9,6 +9,7 @@ import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -333,7 +334,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ...@@ -333,7 +334,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
def is_applicable(self, shape: int | None) -> bool: def is_applicable_for_range(self, compile_range: Range) -> bool:
# When sequence parallelism is enabled, the residual tensor from RMSNorm # When sequence parallelism is enabled, the residual tensor from RMSNorm
# needs to be split along the sequence dimension. However, this dimension # needs to be split along the sequence dimension. However, this dimension
# is symbolic during piecewise compilation, and splitting symbolic shapes # is symbolic during piecewise compilation, and splitting symbolic shapes
...@@ -353,7 +354,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ...@@ -353,7 +354,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
): ):
return True return True
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0 return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph):
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import os import os
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from contextlib import contextmanager from contextlib import contextmanager, nullcontext
from types import CodeType from types import CodeType
from typing import Any from typing import Any
...@@ -13,7 +13,9 @@ import torch._C._dynamo.guards ...@@ -13,7 +13,9 @@ import torch._C._dynamo.guards
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -92,12 +94,29 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -92,12 +94,29 @@ class TorchCompileWithNoGuardsWrapper:
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs):
if self.layerwise_nvtx_tracing_enabled:
args_list = list(args)
kwargs_dict = dict(kwargs)
with layerwise_nvtx_marker_context(
"Torch Compiled Module (input):{}".format(self.__class__.__name__),
self,
in_tensor=args_list,
kwargs=kwargs_dict,
) as ctx:
ctx.result = callable_fn(*args, **kwargs)
return ctx.result
return callable_fn(*args, **kwargs)
def __init__(self): def __init__(self):
self.compiled = False self.compiled = False
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config self.vllm_config = vllm_config
mode = vllm_config.compilation_config.mode mode = vllm_config.compilation_config.mode
self.layerwise_nvtx_tracing_enabled = (
vllm_config.observability_config.enable_layerwise_nvtx_tracing
)
if mode is None: if mode is None:
raise RuntimeError("Compilation mode cannot be NO_COMPILATION") raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
...@@ -107,23 +126,49 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -107,23 +126,49 @@ class TorchCompileWithNoGuardsWrapper:
if isinstance(backend, str) and backend == "inductor": if isinstance(backend, str) and backend == "inductor":
options = vllm_config.compilation_config.inductor_compile_config options = vllm_config.compilation_config.inductor_compile_config
self.first_compile = True
self.evaluate_guards = (
vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
)
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
if mode != CompilationMode.STOCK_TORCH_COMPILE: if mode != CompilationMode.STOCK_TORCH_COMPILE:
# Drop all the guards. # Drop all the guards.
options["guard_filter_fn"] = lambda x: [False for _ in x] if self.evaluate_guards:
assert not envs.VLLM_USE_BYTECODE_HOOK, (
"compilation_config.dynamic_shapes_config.evaluate_guards "
"requires VLLM_USE_BYTECODE_HOOK=0. "
)
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False if envs.VLLM_USE_AOT_COMPILE:
from vllm.compilation.decorators import DynamicShapesType # disabled until https://github.com/pytorch/pytorch/pull/169239
# is picked up.
assert ds_type != DynamicShapesType.BACKED, (
"evaluate_guards for backed shapes requires "
"VLLM_USE_AOT_COMPILE=False. "
)
options["guard_filter_fn"] = lambda x: [
entry.guard_type == "SHAPE_ENV" for entry in x
]
else:
options["guard_filter_fn"] = lambda x: [False for _ in x]
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
compiled_ptr: Any = self.forward compiled_ptr: Any = self.forward
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
if ds_type == DynamicShapesType.UNBACKED: if ds_type == DynamicShapesType.UNBACKED:
if envs.VLLM_USE_BYTECODE_HOOK: # reason is that bytecode does torch._dynamo.eval_frame.
# reason is that bytecode does this hack torch._dynamo.eval_frame. # remove_from_cache(self.original_code_object()) to force a new
# remove_from_cache(self.original_code_object()) to force a new # re-compilation. And if we use
# re-compilation. # compiled_ptr = self.check_invariants_and_forward
raise ValueError( # it will reset all entries.
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. " assert not envs.VLLM_USE_BYTECODE_HOOK, (
) "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
)
assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"
compiled_ptr = self.check_invariants_and_forward compiled_ptr = self.check_invariants_and_forward
if envs.VLLM_USE_AOT_COMPILE: if envs.VLLM_USE_AOT_COMPILE:
...@@ -168,13 +213,25 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -168,13 +213,25 @@ class TorchCompileWithNoGuardsWrapper:
# Make sure a compilation is triggered by clearing dynamo # Make sure a compilation is triggered by clearing dynamo
# cache. # cache.
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object()) torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
return self._compiled_callable(*args, **kwargs) return self._call_with_optional_nvtx_range(
self._compiled_callable, *args, **kwargs
)
else: else:
with self._dispatch_to_compiled_code(): with self._dispatch_to_compiled_code():
return self.forward(*args, **kwargs) return self._call_with_optional_nvtx_range(
self.forward, *args, **kwargs
)
else: else:
with _compilation_context(): ctx = (
return self._compiled_callable(*args, **kwargs) nullcontext()
if self.first_compile or not self.evaluate_guards
else torch.compiler.set_stance("fail_on_recompile")
)
self.first_compile = False
with _compilation_context(), ctx:
return self._call_with_optional_nvtx_range(
self._compiled_callable, *args, **kwargs
)
@abstractmethod @abstractmethod
def forward(self, *args, **kwargs): ... def forward(self, *args, **kwargs): ...
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config.attention import AttentionConfig
from vllm.config.cache import CacheConfig from vllm.config.cache import CacheConfig
from vllm.config.compilation import ( from vllm.config.compilation import (
CompilationConfig, CompilationConfig,
...@@ -23,6 +24,7 @@ from vllm.config.multimodal import MultiModalConfig ...@@ -23,6 +24,7 @@ from vllm.config.multimodal import MultiModalConfig
from vllm.config.observability import ObservabilityConfig from vllm.config.observability import ObservabilityConfig
from vllm.config.parallel import EPLBConfig, ParallelConfig from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig from vllm.config.pooler import PoolerConfig
from vllm.config.profiler import ProfilerConfig
from vllm.config.scheduler import SchedulerConfig from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.speech_to_text import SpeechToTextConfig
...@@ -46,6 +48,8 @@ from vllm.config.vllm import ( ...@@ -46,6 +48,8 @@ from vllm.config.vllm import (
# __all__ should only contain classes and functions. # __all__ should only contain classes and functions.
# Types and globals should be imported from their respective modules. # Types and globals should be imported from their respective modules.
__all__ = [ __all__ = [
# From vllm.config.attention
"AttentionConfig",
# From vllm.config.cache # From vllm.config.cache
"CacheConfig", "CacheConfig",
# From vllm.config.compilation # From vllm.config.compilation
...@@ -86,6 +90,8 @@ __all__ = [ ...@@ -86,6 +90,8 @@ __all__ = [
"SpeechToTextConfig", "SpeechToTextConfig",
# From vllm.config.structured_outputs # From vllm.config.structured_outputs
"StructuredOutputsConfig", "StructuredOutputsConfig",
# From vllm.config.profiler
"ProfilerConfig",
# From vllm.config.utils # From vllm.config.utils
"ConfigType", "ConfigType",
"SupportsMetricsInfo", "SupportsMetricsInfo",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config
from vllm.logger import init_logger
logger = init_logger(__name__)
@config
@dataclass
class AttentionConfig:
"""Configuration for attention mechanisms in vLLM."""
backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically."""
flash_attn_version: Literal[2, 3] | None = None
"""Force vllm to use a specific flash-attention version (2 or 3).
Only valid when using the flash-attention backend."""
use_prefill_decode_attention: bool = False
"""Use separate prefill and decode kernels for attention instead of
the unified triton kernel."""
flash_attn_max_num_splits_for_cuda_graph: int = 32
"""Flash Attention max number splits for cuda graph decode."""
use_cudnn_prefill: bool = False
"""Whether to use cudnn prefill."""
use_trtllm_ragged_deepseek_prefill: bool = False
"""Whether to use TRTLLM ragged deepseek prefill."""
use_trtllm_attention: bool | None = None
"""If set to True/False, use or don't use the TRTLLM attention backend
in flashinfer. If None, auto-detect the attention backend in flashinfer."""
disable_flashinfer_prefill: bool = False
"""Whether to disable flashinfer prefill."""
disable_flashinfer_q_quantization: bool = False
"""If set, when using fp8 kv, do not quantize Q to fp8."""
def compute_hash(self) -> str:
"""
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
from vllm.config.utils import get_hash_factors, hash_factors
ignored_factors: list[str] = []
factors = get_hash_factors(self, ignored_factors)
return hash_factors(factors)
@field_validator("backend", mode="before")
@classmethod
def validate_backend_before(cls, value: Any) -> Any:
"""Enable parsing of the `backend` enum type from string."""
if isinstance(value, str):
return AttentionBackendEnum[value.upper()]
return value
def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
"""Set field from env var if set, with deprecation warning."""
from vllm import envs
if envs.is_set(env_var_name):
value = getattr(envs, env_var_name)
if field_name == "backend":
value = self.validate_backend_before(value)
setattr(self, field_name, value)
logger.warning_once(
"Using %s environment variable is deprecated and will be removed in "
"v0.14.0 or v1.0.0, whichever is soonest. Please use "
"--attention-config.%s command line argument or "
"AttentionConfig(%s=...) config field instead.",
env_var_name,
field_name,
field_name,
)
def __post_init__(self) -> None:
self._set_from_env_if_set("backend", "VLLM_ATTENTION_BACKEND")
self._set_from_env_if_set("flash_attn_version", "VLLM_FLASH_ATTN_VERSION")
self._set_from_env_if_set(
"use_prefill_decode_attention", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION"
)
self._set_from_env_if_set(
"flash_attn_max_num_splits_for_cuda_graph",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
)
self._set_from_env_if_set("use_cudnn_prefill", "VLLM_USE_CUDNN_PREFILL")
self._set_from_env_if_set(
"use_trtllm_ragged_deepseek_prefill",
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
)
self._set_from_env_if_set("use_trtllm_attention", "VLLM_USE_TRTLLM_ATTENTION")
self._set_from_env_if_set(
"disable_flashinfer_prefill", "VLLM_DISABLE_FLASHINFER_PREFILL"
)
self._set_from_env_if_set(
"disable_flashinfer_q_quantization",
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
)
...@@ -29,8 +29,8 @@ CacheDType = Literal[ ...@@ -29,8 +29,8 @@ CacheDType = Literal[
"fp8_inc", "fp8_inc",
"fp8_ds_mla", "fp8_ds_mla",
] ]
MambaDType = Literal["auto", "float32"] MambaDType = Literal["auto", "float32", "float16"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"] KVOffloadingBackend = Literal["native", "lmcache"]
...@@ -77,9 +77,21 @@ class CacheConfig: ...@@ -77,9 +77,21 @@ class CacheConfig:
"""Whether to enable prefix caching.""" """Whether to enable prefix caching."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n """Set the hash algorithm for prefix caching:\n
- "sha256" uses Pickle for object serialization before hashing.\n - "sha256" uses Pickle for object serialization before hashing. This is the
current default, as SHA256 is the most secure choice to avoid potential
hash collisions.\n
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It - "sha256_cbor" provides a reproducible, cross-language compatible hash. It
serializes objects using canonical CBOR and hashes them with SHA-256.""" serializes objects using canonical CBOR and hashes them with SHA-256.\n
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
non-cryptographic hashing. Requires the optional ``xxhash`` package.
IMPORTANT: Use of a hashing algorithm that is not considered
cryptographically secure theoretically increases the risk of hash collisions,
which can cause undefined behavior or even leak private information in
multi-tenant environments. Even if collisions are still very unlikely, it is
important to consider your security risk tolerance against the performance
benefits before turning this on.\n
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
reproducible hashing. Requires the optional ``xxhash`` package."""
cpu_offload_gb: float = Field(default=0, ge=0) cpu_offload_gb: float = Field(default=0, ge=0)
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means """The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to no offloading. Intuitively, this argument can be seen as a virtual way to
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import enum import enum
from collections import Counter from collections import Counter
from collections.abc import Callable from collections.abc import Callable
from dataclasses import asdict, field from dataclasses import field
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal from typing import TYPE_CHECKING, Any, ClassVar, Literal
...@@ -13,7 +13,13 @@ from pydantic.dataclasses import dataclass ...@@ -13,7 +13,13 @@ from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config, handle_deprecated from vllm.config.utils import (
Range,
config,
get_hash_factors,
handle_deprecated,
hash_factors,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
...@@ -173,6 +179,9 @@ class PassConfig: ...@@ -173,6 +179,9 @@ class PassConfig:
""" """
MiB = 1024 * 1024 MiB = 1024 * 1024
FI_SUPPORTED_WORLD_SIZES = [2, 4, 8]
if world_size not in FI_SUPPORTED_WORLD_SIZES:
return None
max_size_mb = self.fi_allreduce_fusion_max_size_mb max_size_mb = self.fi_allreduce_fusion_max_size_mb
if max_size_mb is None: if max_size_mb is None:
max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size)
...@@ -196,7 +205,16 @@ class PassConfig: ...@@ -196,7 +205,16 @@ class PassConfig:
Any new fields that affect compilation should be added to the hash. Any new fields that affect compilation should be added to the hash.
Any future fields that don't affect compilation should be excluded. Any future fields that don't affect compilation should be excluded.
""" """
return InductorPass.hash_dict(asdict(self))
ignored_fields = [
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
]
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
@field_validator( @field_validator(
"fuse_norm_quant", "fuse_norm_quant",
...@@ -267,14 +285,6 @@ class PassConfig: ...@@ -267,14 +285,6 @@ class PassConfig:
"v0.13.0 or v1.0.0, whichever is sooner", "v0.13.0 or v1.0.0, whichever is sooner",
) )
# Force old flags to None to ensure they are not used
self.enable_fusion = None
self.enable_attn_fusion = None
self.enable_noop = None
self.enable_sequence_parallelism = None
self.enable_async_tp = None
self.enable_fi_allreduce_fusion = None
if not self.eliminate_noops: if not self.eliminate_noops:
if self.fuse_norm_quant or self.fuse_act_quant: if self.fuse_norm_quant or self.fuse_act_quant:
logger.warning_once( logger.warning_once(
...@@ -334,7 +344,18 @@ class DynamicShapesConfig: ...@@ -334,7 +344,18 @@ class DynamicShapesConfig:
backed/unbacked. backed/unbacked.
""" """
# TODO add a debug mode to fail evaluate_guards: bool = False
"""
A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by
guarding on it. When True, dynamic shape guards are not dropped from dynamo.
And a failure will be triggered if a recompilation ever happens due to that.
This mode requires VLLM_USE_BYTECODE_HOOK to be 0.
Enabling this allow observing the dynamic shapes guards in the tlparse
artifacts also.
When type is backed, aot_compile must be disabled for this mode to work.
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
...@@ -378,6 +399,8 @@ class CompilationConfig: ...@@ -378,6 +399,8 @@ class CompilationConfig:
[vllm.config.CompilationConfig.cudagraph_copy_inputs] [vllm.config.CompilationConfig.cudagraph_copy_inputs]
- Inductor compilation: - Inductor compilation:
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`compile_ranges_split_points`]
[vllm.config.CompilationConfig.compile_ranges_split_points]
- [`inductor_compile_config`] - [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config] [vllm.config.CompilationConfig.inductor_compile_config]
- [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
...@@ -443,8 +466,8 @@ class CompilationConfig: ...@@ -443,8 +466,8 @@ class CompilationConfig:
We use string to avoid serialization issues when using compilation in a We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation mode is 1 or 2, the backend is distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the used for the compilation directly (it sees the whole graph). When the
compilation mode is 3, the backend supports both whole graph and piecewise compilation mode is 3, the backend supports both whole graph and piecewise
compilation, available backends include eager, inductor, and custom backends, compilation, available backends include eager, inductor, and custom backends,
the latter of which can be defined via `get_compile_backend`. Furthermore, the latter of which can be defined via `get_compile_backend`. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for use_inductor_graph_partition is off. Note that the default options for
...@@ -491,6 +514,21 @@ class CompilationConfig: ...@@ -491,6 +514,21 @@ class CompilationConfig:
to integers, it also supports "cudagraph_capture_sizes" to to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture.""" specify the sizes for cudagraph capture."""
compile_ranges_split_points: list[int] | None = None
"""Split points that represent compile ranges for inductor.
The compile ranges are
[1, split_points[0]],
[split_points[0] + 1, split_points[1]], ...,
[split_points[-1] + 1, max_num_batched_tokens].
Compile sizes are also used single element ranges,
the range is represented as [compile_sizes[i], compile_sizes[i]].
If a range overlaps with the compile size, graph for compile size
will be prioritized, i.e. if we have a range [1, 8] and a compile size 4,
graph for compile size 4 will be compiled and used instead of the graph
for range [1, 8].
"""
inductor_compile_config: dict = field(default_factory=dict) inductor_compile_config: dict = field(default_factory=dict)
"""Additional configurations for inductor. """Additional configurations for inductor.
- None: use default configurations.""" - None: use default configurations."""
...@@ -939,7 +977,9 @@ class CompilationConfig: ...@@ -939,7 +977,9 @@ class CompilationConfig:
# May get recomputed in the model runner if adjustment is needed for spec-decode # May get recomputed in the model runner if adjustment is needed for spec-decode
self.compute_bs_to_padded_graph_size() self.compute_bs_to_padded_graph_size()
def set_splitting_ops_for_v1(self): def set_splitting_ops_for_v1(
self, all2all_backend: str | None = None, data_parallel_size: int | None = None
):
# To compatible with OOT hardware plugin platform (for example vllm-ascend) # To compatible with OOT hardware plugin platform (for example vllm-ascend)
# which currently only supports sequence parallelism in eager mode. # which currently only supports sequence parallelism in eager mode.
if self.mode != CompilationMode.VLLM_COMPILE: if self.mode != CompilationMode.VLLM_COMPILE:
...@@ -954,50 +994,83 @@ class CompilationConfig: ...@@ -954,50 +994,83 @@ class CompilationConfig:
"mode is CompilationMode.VLLM_COMPILE" "mode is CompilationMode.VLLM_COMPILE"
) )
if self.use_inductor_graph_partition: added_default_splitting_ops = False
self.set_splitting_ops_for_inductor_graph_partition()
return
if self.pass_config.fuse_attn_quant: if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
# here use_inductor_graph_partition is False
self.set_splitting_ops_for_attn_fusion() self.set_splitting_ops_for_attn_fusion()
return else:
if self.splitting_ops is None:
if self.splitting_ops is None: # NOTE: When using full cudagraph, instead of setting an empty
# NOTE: When using full cudagraph, instead of setting an empty # list and capture the full cudagraph inside the flattened fx
# list and capture the full cudagraph inside the flattened fx # graph, we keep the piecewise fx graph structure but capture
# graph, we keep the piecewise fx graph structure but capture # the full cudagraph outside the fx graph. This reduces some
# the full cudagraph outside the fx graph. This reduces some # cpu overhead when the runtime batch_size is not cudagraph
# cpu overhead when the runtime batch_size is not cudagraph # captured. see https://github.com/vllm-project/vllm/pull/20059
# captured. see https://github.com/vllm-project/vllm/pull/20059 # for details. Make a copy to avoid mutating the class-level
# for details. Make a copy to avoid mutating the class-level # list via reference.
# list via reference. self.splitting_ops = list(self._attention_ops)
self.splitting_ops = list(self._attention_ops) added_default_splitting_ops = True
elif len(self.splitting_ops) == 0: elif len(self.splitting_ops) == 0:
logger.warning_once("Using piecewise compilation with empty splitting_ops")
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once( logger.warning_once(
"Piecewise compilation with empty splitting_ops do not" "Using piecewise compilation with empty splitting_ops"
"contains piecewise cudagraph. Setting cudagraph_"
"mode to NONE. Hint: If you are using attention backends "
"that support cudagraph, consider manually setting "
"cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
"full cudagraphs."
) )
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not"
"contains piecewise cudagraph. Setting cudagraph_"
"mode to NONE. Hint: If you are using attention "
"backends that support cudagraph, consider manually "
"setting cudagraph_mode to FULL or FULL_DECODE_ONLY "
"to enable full cudagraphs."
)
self.cudagraph_mode = CUDAGraphMode.NONE
elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do "
"not contains piecewise cudagraph. Setting "
"cudagraph_mode to FULL."
)
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
# split MoE ops for cudagraph
moe_ops = [
"vllm::moe_forward",
"vllm::moe_forward_shared",
]
backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
dp_size = data_parallel_size if data_parallel_size is not None else 1
need_moe_splitting = (
backend == "deepep_high_throughput"
and dp_size > 1
# pure attn-fusion without inductor partition deliberately disables
# piecewise graphs and MoE splitting.
and not (
self.pass_config.fuse_attn_quant
and not self.use_inductor_graph_partition
)
)
if need_moe_splitting and self.cudagraph_mode != CUDAGraphMode.NONE:
# if we just initialized default splitting_ops for this config,
# automatically append the MoE ops
if added_default_splitting_ops:
for op in moe_ops:
if op not in self.splitting_ops:
self.splitting_ops.append(op)
# make sure MoE ops are split out
if not any(op in self.splitting_ops for op in moe_ops):
self.cudagraph_mode = CUDAGraphMode.NONE self.cudagraph_mode = CUDAGraphMode.NONE
elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
logger.warning_once( logger.warning_once(
"Piecewise compilation with empty splitting_ops do not " "DeepEP high throughput backend with data_parallel_size > 1 "
"contains piecewise cudagraph. Setting cudagraph_mode " "requires splitting MoE ops from cudagraphs. Please ensure "
"to FULL." "'vllm::moe_forward' or 'vllm::moe_forward_shared' are "
"present in CompilationConfig.splitting_ops."
) )
self.cudagraph_mode = CUDAGraphMode.FULL elif self.cudagraph_mode.has_full_cudagraphs():
self.splitting_ops = [] # fall back to piecewise when MoE splitting is required.
self.cudagraph_mode = CUDAGraphMode.PIECEWISE
def set_splitting_ops_for_inductor_graph_partition(self):
assert self.use_inductor_graph_partition
if self.splitting_ops is None:
self.splitting_ops = list(self._attention_ops)
def set_splitting_ops_for_attn_fusion(self): def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.fuse_attn_quant assert self.pass_config.fuse_attn_quant
...@@ -1152,3 +1225,13 @@ class CompilationConfig: ...@@ -1152,3 +1225,13 @@ class CompilationConfig:
self.bs_to_padded_graph_size[bs] = start self.bs_to_padded_graph_size[bs] = start
else: else:
self.bs_to_padded_graph_size[bs] = end self.bs_to_padded_graph_size[bs] = end
def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config."""
if self.compile_ranges_split_points is None:
return []
split_points = sorted(set(self.compile_ranges_split_points))
return [
Range(start=s + 1, end=e)
for s, e in zip([0] + split_points[:-1], split_points)
]
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import warnings import warnings
from collections.abc import Callable from collections.abc import Callable
from dataclasses import InitVar, field from dataclasses import InitVar, field
from importlib.util import find_spec from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch import torch
...@@ -37,15 +37,13 @@ from vllm.transformers_utils.config import ( ...@@ -37,15 +37,13 @@ from vllm.transformers_utils.config import (
uses_xdrope_dim, uses_xdrope_dim,
) )
from vllm.transformers_utils.gguf_utils import ( from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf,
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import (
is_gguf, is_gguf,
is_remote_gguf, is_remote_gguf,
maybe_model_redirect, maybe_patch_hf_config_from_gguf,
split_remote_gguf, split_remote_gguf,
) )
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype from vllm.utils.torch_utils import common_broadcastable_dtype
...@@ -86,7 +84,7 @@ TaskOption = Literal[ ...@@ -86,7 +84,7 @@ TaskOption = Literal[
"transcription", "transcription",
"draft", "draft",
] ]
TokenizerMode = Literal["auto", "hf", "slow", "mistral"] TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[ LogprobsMode = Literal[
"raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs" "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
...@@ -139,10 +137,12 @@ class ModelConfig: ...@@ -139,10 +137,12 @@ class ModelConfig:
name or path will be used.""" name or path will be used."""
tokenizer_mode: TokenizerMode | str = "auto" tokenizer_mode: TokenizerMode | str = "auto"
"""Tokenizer mode:\n """Tokenizer mode:\n
- "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.\n - "auto" will use the tokenizer from `mistral_common` for Mistral models
if available, otherwise it will use the "hf" tokenizer.\n
- "hf" will use the fast tokenizer if available.\n - "hf" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n - "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n - "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- Other custom values can be supported via plugins.""" - Other custom values can be supported via plugins."""
trust_remote_code: bool = False trust_remote_code: bool = False
"""Trust remote code (e.g., from HuggingFace) when downloading the model """Trust remote code (e.g., from HuggingFace) when downloading the model
...@@ -471,18 +471,6 @@ class ModelConfig: ...@@ -471,18 +471,6 @@ class ModelConfig:
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
if (
(backend := envs.VLLM_ATTENTION_BACKEND)
and backend == "FLASHINFER"
and find_spec("flashinfer") is None
):
raise ValueError(
"VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer "
"module was not found. See "
"https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501
"for instructions on how to install it."
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
if self.override_attention_dtype is not None and not current_platform.is_rocm(): if self.override_attention_dtype is not None and not current_platform.is_rocm():
...@@ -531,7 +519,11 @@ class ModelConfig: ...@@ -531,7 +519,11 @@ class ModelConfig:
if task == "classify": if task == "classify":
return "classify" return "classify"
if task == "reward": if task == "reward":
return "reward" logger.warning(
"Pooling models now default support all pooling; "
"you can use it without any settings."
)
return "embed"
if task == "score": if task == "score":
new_task = self._get_default_pooling_task(architectures) new_task = self._get_default_pooling_task(architectures)
return "classify" if new_task == "classify" else "embed" return "classify" if new_task == "classify" else "embed"
...@@ -1233,6 +1225,19 @@ class ModelConfig: ...@@ -1233,6 +1225,19 @@ class ModelConfig:
) )
return False return False
@cached_property
def is_mm_prefix_lm(self) -> bool:
"""Whether to use bidirectional attention for mm positions."""
MM_PREFIX_LM_MODELS = (
"gemma3",
# TODO(Isotr0py): Disable paligemma for now before
# we supports soft cap attention for FlexAttention
# "paligemma",
)
if not hasattr(self.hf_config, "model_type"):
return False
return self.hf_config.model_type in MM_PREFIX_LM_MODELS
def get_head_size(self) -> int: def get_head_size(self) -> int:
# TODO remove hard code # TODO remove hard code
if self.is_deepseek_mla: if self.is_deepseek_mla:
...@@ -1784,20 +1789,22 @@ class ModelConfig: ...@@ -1784,20 +1789,22 @@ class ModelConfig:
return False return False
elif attn_type == "decoder": elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower() pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]: if pooling_type in ["mean", "step", "cls"]:
logger.debug( logger.debug(
"Pooling models with %s pooling does not " "Pooling models with %s pooling does not "
"support chunked prefill.", "support chunked prefill.",
pooling_type, pooling_type,
) )
return False return False
else: elif pooling_type in ["all", "last"]:
# pooling_type == "last"
logger.debug( logger.debug(
"Pooling models with causal attn and last pooling support " "Pooling models with causal attn and %s pooling support "
"chunked prefill." "chunked prefill.",
pooling_type,
) )
return True return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid, # vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types. # attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder" return attn_type != "encoder_decoder"
...@@ -1821,20 +1828,22 @@ class ModelConfig: ...@@ -1821,20 +1828,22 @@ class ModelConfig:
return False return False
elif attn_type == "decoder": elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower() pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]: if pooling_type in ["mean", "step", "cls"]:
logger.debug( logger.debug(
"Pooling models with %s pooling does not " "Pooling models with %s pooling does not "
"support prefix caching.", "support prefix caching.",
pooling_type, pooling_type,
) )
return False return False
else: elif pooling_type in ["all", "last"]:
# pooling_type == "last"
logger.debug( logger.debug(
"Pooling models with causal attn and last pooling support " "Pooling models with causal attn and %s pooling support "
"prefix caching." "prefix caching.",
pooling_type,
) )
return True return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid, # vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types. # attention_free or encoder_decoder attn types.
return False return False
...@@ -1897,8 +1906,8 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [ ...@@ -1897,8 +1906,8 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
("ForImageClassification", ("pooling", "classify")), ("ForImageClassification", ("pooling", "classify")),
("ForVideoClassification", ("pooling", "classify")), ("ForVideoClassification", ("pooling", "classify")),
("ClassificationModel", ("pooling", "classify")), ("ClassificationModel", ("pooling", "classify")),
("ForRewardModeling", ("pooling", "reward")), ("ForRewardModeling", ("pooling", "embed")),
("RewardModel", ("pooling", "reward")), ("RewardModel", ("pooling", "embed")),
# Let other `*Model`s take priority # Let other `*Model`s take priority
("Model", ("pooling", "embed")), ("Model", ("pooling", "embed")),
] ]
......
...@@ -55,6 +55,15 @@ class ObservabilityConfig: ...@@ -55,6 +55,15 @@ class ObservabilityConfig:
kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1) kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks.""" """Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
cudagraph_metrics: bool = False
"""Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
dispatch modes, and their observed frequencies at every logging interval)."""
enable_layerwise_nvtx_tracing: bool = False
"""Enable layerwise NVTX tracing. This traces the execution of each layer or
module in the model and attach informations such as input/output shapes to
nvtx range markers. Noted that this doesn't work with CUDA graphs enabled."""
@cached_property @cached_property
def collect_model_forward_time(self) -> bool: def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request.""" """Whether to collect model forward time for the request."""
......
...@@ -35,6 +35,7 @@ logger = init_logger(__name__) ...@@ -35,6 +35,7 @@ logger = init_logger(__name__)
ExpertPlacementStrategy = Literal["linear", "round_robin"] ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"] DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"]
@config @config
...@@ -65,6 +66,9 @@ class EPLBConfig: ...@@ -65,6 +66,9 @@ class EPLBConfig:
Whether to use non-blocking EPLB. Whether to use non-blocking EPLB.
""" """
policy: EPLBPolicyOption = "default"
"""The policy type for expert parallel load balancing (EPLB)."""
@config @config
@dataclass @dataclass
...@@ -180,13 +184,14 @@ class ParallelConfig: ...@@ -180,13 +184,14 @@ class ParallelConfig:
distributed_executor_backend: ( distributed_executor_backend: (
str | DistributedExecutorBackend | type[Executor] | None str | DistributedExecutorBackend | type[Executor] | None
) = None ) = None
"""Backend to use for distributed model """Backend to use for distributed model workers, either "ray" or "mp"
workers, either "ray" or "mp" (multiprocessing). If the product (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size
of pipeline_parallel_size and tensor_parallel_size is less than is less than or equal to the number of GPUs available, "mp" will be used to
or equal to the number of GPUs available, "mp" will be used to keep processing on a single host. Otherwise, an error will be raised. To use "mp"
keep processing on a single host. Otherwise, this will default you must also set nnodes, and to use "ray" you must manually set
to "ray" if Ray is installed and fail otherwise. Note that tpu distributed_executor_backend to "ray".
only support Ray for distributed inference."""
Note that tpu only support Ray for distributed inference."""
worker_cls: str = "auto" worker_cls: str = "auto"
"""The full name of the worker class to use. If "auto", the worker class """The full name of the worker class to use. If "auto", the worker class
...@@ -562,8 +567,11 @@ class ParallelConfig: ...@@ -562,8 +567,11 @@ class ParallelConfig:
): ):
gpu_count = cuda_device_count_stateless() gpu_count = cuda_device_count_stateless()
raise ValueError( raise ValueError(
f"Tensor parallel size ({self.world_size}) cannot be " f"World size ({self.world_size}) is larger than the number of "
f"larger than the number of available GPUs ({gpu_count})." f"available GPUs ({gpu_count}) in this node. If this is "
"intentional and you are using:\n"
"- ray, set '--distributed-executor-backend ray'.\n"
"- multiprocessing, set '--nnodes' appropriately."
) )
elif self.data_parallel_backend == "ray": elif self.data_parallel_backend == "ray":
logger.info( logger.info(
...@@ -593,10 +601,14 @@ class ParallelConfig: ...@@ -593,10 +601,14 @@ class ParallelConfig:
"max_parallel_loading_workers is currently " "max_parallel_loading_workers is currently "
"not supported and will be ignored." "not supported and will be ignored."
) )
if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1: allowed_backends = ("mp", "uni", "external_launcher")
if (
self.distributed_executor_backend not in allowed_backends
and self.nnodes > 1
):
raise ValueError( raise ValueError(
"nnodes > 1 can only be set when distributed executor " "nnodes > 1 can only be set when distributed executor "
"backend is mp or uni." "backend is mp, uni or external_launcher."
) )
@property @property
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any, Literal
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
ProfilerKind = Literal["torch", "cuda"]
@config
@dataclass
class ProfilerConfig:
"""Dataclass which contains profiler config for the engine."""
profiler: ProfilerKind | None = None
"""Which profiler to use. Defaults to None. Options are:
- 'torch': Use PyTorch profiler.\n
- 'cuda': Use CUDA profiler."""
torch_profiler_dir: str = ""
"""Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
worker's traces (CPU & GPU) will be saved under this directory. Note that
it must be an absolute path."""
torch_profiler_with_stack: bool = True
"""If `True`, enables stack tracing in the torch profiler. Enabled by default."""
torch_profiler_with_flops: bool = False
"""If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""
torch_profiler_use_gzip: bool = True
"""If `True`, saves torch profiler traces in gzip format. Enabled by default"""
torch_profiler_dump_cuda_time_total: bool = True
"""If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""
torch_profiler_record_shapes: bool = False
"""If `True`, records tensor shapes in the torch profiler. Disabled by default."""
torch_profiler_with_memory: bool = False
"""If `True`, enables memory profiling in the torch profiler.
Disabled by default."""
ignore_frontend: bool = False
"""If `True`, disables the front-end profiling of AsyncLLM when using the
'torch' profiler. This is needed to reduce overhead when using delay/limit options,
since the front-end profiling does not track iterations and will capture the
entire range.
"""
delay_iterations: int = Field(default=0, ge=0)
"""Number of engine iterations to skip before starting profiling.
Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
"""
max_iterations: int = Field(default=0, ge=0)
"""Maximum number of engine iterations to profile after starting profiling.
Defaults to 0, meaning no limit.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
"""Get field from env var if set, with deprecation warning."""
if envs.is_set(env_var_name):
value = getattr(envs, env_var_name)
logger.warning_once(
"Using %s environment variable is deprecated and will be removed in "
"v0.14.0 or v1.0.0, whichever is soonest. Please use "
"--profiler-config.%s command line argument or "
"ProfilerConfig(%s=...) config field instead.",
env_var_name,
field_name,
field_name,
)
return value
return None
def _set_from_env_if_set(
self,
field_name: str,
env_var_name: str,
to_bool: bool = True,
to_int: bool = False,
) -> None:
"""Set field from env var if set, with deprecation warning."""
value = self._get_from_env_if_set(field_name, env_var_name)
if value is not None:
if to_bool:
value = value == "1"
if to_int:
value = int(value)
setattr(self, field_name, value)
@model_validator(mode="after")
def _validate_profiler_config(self) -> Self:
maybe_use_cuda_profiler = self._get_from_env_if_set(
"profiler", "VLLM_TORCH_CUDA_PROFILE"
)
if maybe_use_cuda_profiler is not None:
self.profiler = "cuda" if maybe_use_cuda_profiler == "1" else None
else:
self._set_from_env_if_set(
"torch_profiler_dir", "VLLM_TORCH_PROFILER_DIR", to_bool=False
)
if self.torch_profiler_dir:
self.profiler = "torch"
self._set_from_env_if_set(
"torch_profiler_record_shapes",
"VLLM_TORCH_PROFILER_RECORD_SHAPES",
)
self._set_from_env_if_set(
"torch_profiler_with_memory",
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY",
)
self._set_from_env_if_set(
"torch_profiler_with_stack",
"VLLM_TORCH_PROFILER_WITH_STACK",
)
self._set_from_env_if_set(
"torch_profiler_with_flops",
"VLLM_TORCH_PROFILER_WITH_FLOPS",
)
self._set_from_env_if_set(
"ignore_frontend",
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM",
)
self._set_from_env_if_set(
"torch_profiler_use_gzip",
"VLLM_TORCH_PROFILER_USE_GZIP",
)
self._set_from_env_if_set(
"torch_profiler_dump_cuda_time_total",
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL",
)
self._set_from_env_if_set(
"delay_iterations", "VLLM_PROFILER_DELAY_ITERS", to_bool=False, to_int=True
)
self._set_from_env_if_set(
"max_iterations", "VLLM_PROFILER_MAX_ITERS", to_bool=False, to_int=True
)
has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0
if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend:
logger.warning_once(
"Using 'torch' profiler with delay_iterations or max_iterations "
"while ignore_frontend is False may result in high overhead."
)
profiler_dir = self.torch_profiler_dir
if profiler_dir and self.profiler != "torch":
raise ValueError(
"torch_profiler_dir is only applicable when profiler is set to 'torch'"
)
if self.profiler == "torch" and not profiler_dir:
raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
if profiler_dir:
is_gs_path = (
profiler_dir.startswith("gs://")
and profiler_dir[5:]
and profiler_dir[5] != "/"
)
if not is_gs_path:
self.torch_profiler_dir = os.path.abspath(
os.path.expanduser(profiler_dir)
)
return self
...@@ -337,6 +337,7 @@ class SpeculativeConfig: ...@@ -337,6 +337,7 @@ class SpeculativeConfig:
enforce_eager=self.target_model_config.enforce_eager, enforce_eager=self.target_model_config.enforce_eager,
max_logprobs=self.target_model_config.max_logprobs, max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override, hf_overrides=SpeculativeConfig.hf_config_override,
config_format=self.target_model_config.config_format,
) )
# Automatically detect the method # Automatically detect the method
......
...@@ -65,22 +65,6 @@ class StructuredOutputsConfig: ...@@ -65,22 +65,6 @@ class StructuredOutputsConfig:
@model_validator(mode="after") @model_validator(mode="after")
def _validate_structured_output_config(self) -> Self: def _validate_structured_output_config(self) -> Self:
# Import here to avoid circular import
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
if self.reasoning_parser_plugin and len(self.reasoning_parser_plugin) > 3:
ReasoningParserManager.import_reasoning_parser(self.reasoning_parser_plugin)
valid_reasoning_parsers = ReasoningParserManager.list_registered()
if (
self.reasoning_parser != ""
and self.reasoning_parser not in valid_reasoning_parsers
):
raise ValueError(
f"invalid reasoning parser: {self.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
)
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"): if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
raise ValueError( raise ValueError(
"disable_any_whitespace is only supported for " "disable_any_whitespace is only supported for "
......
...@@ -10,7 +10,7 @@ import json ...@@ -10,7 +10,7 @@ import json
import pathlib import pathlib
import textwrap import textwrap
from collections.abc import Iterable, Mapping, Sequence, Set from collections.abc import Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar from typing import TYPE_CHECKING, Any, Protocol, TypeVar
...@@ -322,3 +322,35 @@ def handle_deprecated( ...@@ -322,3 +322,35 @@ def handle_deprecated(
for new_name in new_names: for new_name in new_names:
setattr(config, new_name, old_val) setattr(config, new_name, old_val)
@dataclass
class Range:
"""
A range of numbers.
Inclusive of start, inclusive of end.
"""
start: int
end: int
def is_single_size(self) -> bool:
return self.start == self.end
def __contains__(self, size: int) -> bool:
# Inclusive of start, inclusive of end
return self.start <= size <= self.end
def __eq__(self, other: object) -> bool:
if not isinstance(other, Range):
return False
return self.start == other.start and self.end == other.end
def __hash__(self) -> int:
return hash((self.start, self.end))
def __str__(self) -> str:
return f"({self.start}, {self.end})"
def __repr__(self) -> str:
return self.__str__()
...@@ -27,6 +27,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri ...@@ -27,6 +27,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
from .attention import AttentionConfig
from .cache import CacheConfig from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig from .device import DeviceConfig
...@@ -38,6 +39,7 @@ from .lora import LoRAConfig ...@@ -38,6 +39,7 @@ from .lora import LoRAConfig
from .model import ModelConfig from .model import ModelConfig
from .observability import ObservabilityConfig from .observability import ObservabilityConfig
from .parallel import ParallelConfig from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig from .structured_outputs import StructuredOutputsConfig
...@@ -65,7 +67,7 @@ class OptimizationLevel(IntEnum): ...@@ -65,7 +67,7 @@ class OptimizationLevel(IntEnum):
"""O0 : No optimization. no compilation, no cudagraphs, no other """O0 : No optimization. no compilation, no cudagraphs, no other
optimization, just starting up immediately""" optimization, just starting up immediately"""
O1 = 1 O1 = 1
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise """O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
cudagraphs""" cudagraphs"""
O2 = 2 O2 = 2
"""O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs.""" """O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
...@@ -192,6 +194,8 @@ class VllmConfig: ...@@ -192,6 +194,8 @@ class VllmConfig:
"""Device configuration.""" """Device configuration."""
load_config: LoadConfig = Field(default_factory=LoadConfig) load_config: LoadConfig = Field(default_factory=LoadConfig)
"""Load configuration.""" """Load configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
lora_config: LoRAConfig | None = None lora_config: LoRAConfig | None = None
"""LoRA configuration.""" """LoRA configuration."""
speculative_config: SpeculativeConfig | None = None speculative_config: SpeculativeConfig | None = None
...@@ -215,6 +219,8 @@ class VllmConfig: ...@@ -215,6 +219,8 @@ class VllmConfig:
You can specify the full compilation config like so: You can specify the full compilation config like so:
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
""" """
profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
"""Profiling configuration."""
kv_transfer_config: KVTransferConfig | None = None kv_transfer_config: KVTransferConfig | None = None
"""The configurations for distributed KV cache transfer.""" """The configurations for distributed KV cache transfer."""
kv_events_config: KVEventsConfig | None = None kv_events_config: KVEventsConfig | None = None
...@@ -279,6 +285,10 @@ class VllmConfig: ...@@ -279,6 +285,10 @@ class VllmConfig:
vllm_factors.append(self.load_config.compute_hash()) vllm_factors.append(self.load_config.compute_hash())
else: else:
vllm_factors.append("None") vllm_factors.append("None")
if self.attention_config:
vllm_factors.append(self.attention_config.compute_hash())
else:
vllm_factors.append("None")
if self.lora_config: if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash()) vllm_factors.append(self.lora_config.compute_hash())
else: else:
...@@ -289,6 +299,8 @@ class VllmConfig: ...@@ -289,6 +299,8 @@ class VllmConfig:
vllm_factors.append("None") vllm_factors.append("None")
if self.structured_outputs_config: if self.structured_outputs_config:
vllm_factors.append(self.structured_outputs_config.compute_hash()) vllm_factors.append(self.structured_outputs_config.compute_hash())
if self.profiler_config:
vllm_factors.append(self.profiler_config.compute_hash())
else: else:
vllm_factors.append("None") vllm_factors.append("None")
vllm_factors.append(self.observability_config.compute_hash()) vllm_factors.append(self.observability_config.compute_hash())
...@@ -579,6 +591,15 @@ class VllmConfig: ...@@ -579,6 +591,15 @@ class VllmConfig:
else: else:
self.scheduler_config.async_scheduling = True self.scheduler_config.async_scheduling = True
if (
self.scheduler_config.async_scheduling
and not self.parallel_config.disable_nccl_for_dp_synchronization
):
logger.info(
"Disabling NCCL for DP synchronization when using async scheduling."
)
self.parallel_config.disable_nccl_for_dp_synchronization = True
from vllm.platforms import current_platform from vllm.platforms import current_platform
if ( if (
...@@ -671,36 +692,22 @@ class VllmConfig: ...@@ -671,36 +692,22 @@ class VllmConfig:
if current_platform.support_static_graph_mode(): if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support # if cudagraph_mode has full cudagraphs, we need to check support
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if (
# decode context parallel does not support full cudagraphs self.compilation_config.cudagraph_mode.has_full_cudagraphs()
if self.parallel_config.decode_context_parallel_size > 1: and self.model_config is not None
):
if self.model_config.pooler_config is not None:
logger.warning_once( logger.warning_once(
"Decode context parallel (DCP) is enabled, which is " "Pooling models do not support full cudagraphs. "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE." "Overriding cudagraph_mode to PIECEWISE."
) )
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# prefill context parallel do not support full cudagraphs elif self.model_config.is_encoder_decoder:
elif self.parallel_config.prefill_context_parallel_size > 1:
logger.warning_once( logger.warning_once(
"Prefill context parallel (PCP) is enabled, which is " "Encoder-decoder models do not support full cudagraphs. "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE." "Overriding cudagraph_mode to PIECEWISE."
) )
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config is not None:
if self.model_config.pooler_config is not None:
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution # disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager: if self.model_config is not None and self.model_config.enforce_eager:
...@@ -732,6 +739,8 @@ class VllmConfig: ...@@ -732,6 +739,8 @@ class VllmConfig:
"--kv-sharing-fast-prefill requires changes on model side for " "--kv-sharing-fast-prefill requires changes on model side for "
"correctness and to realize prefill savings. " "correctness and to realize prefill savings. "
) )
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
self._set_compile_ranges()
if self.model_config and self.model_config.is_encoder_decoder: if self.model_config and self.model_config.is_encoder_decoder:
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -809,7 +818,10 @@ class VllmConfig: ...@@ -809,7 +818,10 @@ class VllmConfig:
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now." ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
# Do this after all the updates to compilation_config.mode # Do this after all the updates to compilation_config.mode
self.compilation_config.set_splitting_ops_for_v1() self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend,
data_parallel_size=self.parallel_config.data_parallel_size,
)
if self.compilation_config.pass_config.enable_sp: if self.compilation_config.pass_config.enable_sp:
# With pipeline parallelism or dynamo partitioning, # With pipeline parallelism or dynamo partitioning,
...@@ -1035,8 +1047,14 @@ class VllmConfig: ...@@ -1035,8 +1047,14 @@ class VllmConfig:
self.compilation_config.max_cudagraph_capture_size self.compilation_config.max_cudagraph_capture_size
) )
if max_cudagraph_capture_size is None: if max_cudagraph_capture_size is None:
decode_query_len = 1
if (
self.speculative_config
and self.speculative_config.num_speculative_tokens
):
decode_query_len += self.speculative_config.num_speculative_tokens
max_cudagraph_capture_size = min( max_cudagraph_capture_size = min(
self.scheduler_config.max_num_seqs * 2, 512 self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
) )
max_num_tokens = self.scheduler_config.max_num_batched_tokens max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size) max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
...@@ -1133,6 +1151,52 @@ class VllmConfig: ...@@ -1133,6 +1151,52 @@ class VllmConfig:
# complete the remaining process. # complete the remaining process.
self.compilation_config.post_init_cudagraph_sizes() self.compilation_config.post_init_cudagraph_sizes()
def _set_compile_ranges(self):
"""
Set the compile ranges for the compilation config.
"""
compilation_config = self.compilation_config
computed_compile_ranges_split_points = []
# The upper bound of the compile ranges is the max_num_batched_tokens
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
if max_num_batched_tokens is not None:
computed_compile_ranges_split_points.append(max_num_batched_tokens)
# Add the compile ranges for flashinfer
if compilation_config.pass_config.fuse_allreduce_rms:
tp_size = self.parallel_config.tensor_parallel_size
max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
if max_size is not None:
max_token_num = max_size // (
self.model_config.get_hidden_size()
* self.model_config.dtype.itemsize
)
if (
max_num_batched_tokens is not None
and max_token_num < max_num_batched_tokens
):
computed_compile_ranges_split_points.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below allreduce-rms fusion threshold, "
"allreduce-rms fusion will be enabled for all num_tokens."
)
if compilation_config.compile_ranges_split_points is not None:
for x in compilation_config.compile_ranges_split_points:
assert isinstance(x, int)
assert x > 0, f"Invalid compile range split point: {x}"
if (
max_num_batched_tokens is not None
and x < max_num_batched_tokens
and x > 1
):
computed_compile_ranges_split_points.append(x)
compilation_config.compile_ranges_split_points = sorted(
computed_compile_ranges_split_points
)
def recalculate_max_model_len(self, max_model_len: int): def recalculate_max_model_len(self, max_model_len: int):
# Can only be called in try_verify_and_update_config # Can only be called in try_verify_and_update_config
model_config = self.model_config model_config = self.model_config
......
...@@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device output_shape, dtype=input_tensor.dtype, device=input_tensor.device
) )
if sizes is not None: if sizes is not None and sizes.count(sizes[0]) != len(sizes):
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes) pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else: else:
pynccl_comm.reduce_scatter(output, input_tensor) pynccl_comm.reduce_scatter(output, input_tensor)
......
...@@ -27,6 +27,7 @@ from zmq import ( # type: ignore ...@@ -27,6 +27,7 @@ from zmq import ( # type: ignore
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup, sched_yield from vllm.distributed.utils import StatelessProcessGroup, sched_yield
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import ( from vllm.utils.network_utils import (
get_ip, get_ip,
get_open_port, get_open_port,
...@@ -632,7 +633,7 @@ class MessageQueue: ...@@ -632,7 +633,7 @@ class MessageQueue:
The MessageQueue instance for the calling process, The MessageQueue instance for the calling process,
and a list of handles (only non-empty for the reader process). and a list of handles (only non-empty for the reader process).
""" """
local_size = torch.cuda.device_count() local_size = current_platform.device_count()
rank = dist.get_rank() rank = dist.get_rank()
same_node = rank // local_size == reader_rank // local_size same_node = rank // local_size == reader_rank // local_size
buffer_io = MessageQueue( buffer_io = MessageQueue(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment