Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
......@@ -5,6 +5,7 @@ from collections.abc import Iterable
import torch.fx
from torch import SymInt
from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.logger import init_logger
......@@ -116,12 +117,7 @@ class NoOpEliminationPass(VllmInductorPass):
2. The dimensions both correspond to the same SymInt
"""
# Case 1
if isinstance(i_dim, int) and isinstance(dim, int):
return dim == i_dim
# Case 2
if isinstance(i_dim, SymInt) and isinstance(dim, SymInt):
return dim == i_dim
return False
return statically_known_true(dim == i_dim)
def all_dims_equivalent(
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
......
......@@ -5,6 +5,7 @@ import functools
from torch import fx as fx
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var
from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import RMSNormQuantFusionPass
......@@ -24,7 +31,11 @@ if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .inductor_pass import (
CustomGraphPass,
InductorPass,
get_pass_context,
)
from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)
......@@ -70,13 +81,13 @@ class PostGradPassManager(CustomGraphPass):
def __call__(self, graph: fx.Graph):
VllmInductorPass.dump_prefix = 0 # reset dump index
shape = get_pass_context().runtime_shape
compile_range = get_pass_context().compile_range
for pass_ in self.passes:
if pass_.is_applicable(shape):
if pass_.is_applicable_for_range(compile_range):
pass_(graph)
VllmInductorPass.dump_prefix += 1
else:
logger.debug("Skipping %s with shape %s", pass_, shape)
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
# post-cleanup goes before fix_functionalization
# because it requires a functional graph
......@@ -105,8 +116,12 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)]
......@@ -133,4 +148,8 @@ class PostGradPassManager(CustomGraphPass):
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
# Include the compile range in the uuid to ensure that inductor
# recompiles the graph for the new dynamic compile range.
state["compile_range"] = str(get_pass_context().compile_range)
return InductorPass.hash_dict(state)
......@@ -7,18 +7,18 @@ from typing import Any
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int
class RangeEntry:
compile_range: Range
compiled: bool = False
runnable: Callable = None # type: ignore
......@@ -31,7 +31,6 @@ class PiecewiseBackend:
piecewise_compile_index: int,
total_piecewise_compiles: int,
sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend,
):
"""
......@@ -54,68 +53,131 @@ class PiecewiseBackend:
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.is_full_graph = total_piecewise_compiles == 1
# TODO: we need to generalize encoder compilation to other models
self.is_encoder_compilation = vllm_backend.prefix in [
"Qwen2_5_VisionPatchEmbed",
"Qwen2_5_VisionPatchMerger",
"Qwen2_5_VisionBlock",
]
self.compile_ranges = self.compilation_config.get_compile_ranges()
if self.is_encoder_compilation:
# For encoder compilation we use the max int32 value
# to set the upper bound of the compile ranges
max_int32 = 2**31 - 1
last_compile_range = self.compile_ranges[-1]
assert (
last_compile_range.end
== vllm_config.scheduler_config.max_num_batched_tokens
)
self.compile_ranges[-1] = Range(
start=last_compile_range.start, end=max_int32
)
self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes)
self.first_run_finished = False
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
logger.debug_once(log_string)
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
self.compile_sizes = self.compilation_config.compile_sizes
log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
logger.debug_once(log_string)
self.sym_shape_indices = sym_shape_indices
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
# the entries for the ranges (including single-size ranges) that we need to compile
self.range_entries: dict[Range, RangeEntry] = {}
# the entries for different shapes that we need to compile
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
# to_be_compiled_sizes tracks the remaining sizes to compile,
# to_be_compiled_ranges tracks the remaining ranges to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
# We only keep compilation management inside this class directly.
for shape in self.compile_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
runnable=self.compiled_graph_for_general_shape,
for size in self.compile_sizes:
range = Range(start=size, end=size)
if range not in self.compile_ranges:
self.range_entries[range] = RangeEntry(
compile_range=range,
)
self.to_be_compiled_ranges.add(range)
for range in self.compile_ranges:
self.range_entries[range] = RangeEntry(
compile_range=range,
)
def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
if self.is_last_graph and not self.to_be_compiled_ranges:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
self.check_for_ending_compilation()
return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]]
if runtime_shape not in self.concrete_size_entries:
# we don't need to do anything for this shape
return self.compiled_graph_for_general_shape(*args)
def _fakify_args(self, args: list[Any]) -> list[Any]:
# We need to pass fake example_inputs, otherwise torch.compile
# will fakify the example_inputs potentially causing some non dynamic
# dimension to be be duck shaped to other existing shapes that have hints
# matching their values.
# This is problem because it can lead to unintended specializations!
# if the new wrongly dynamic dim is specialized
# it will force specializing the whole shape
# torch.compile probably should not accept
# non fake tensors as example inputs!
# See issue https://github.com/vllm-project/vllm/issues/27899
fake_example_inputs = []
for node in self.graph.graph.nodes:
# All place holders come first
if node.op == "placeholder":
fake_example_inputs.append(node.meta["example_value"])
else:
break
assert len(fake_example_inputs) == len(args)
return fake_example_inputs
def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any:
if not range_entry.compiled:
range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
entry = self.concrete_size_entries[runtime_shape]
if not entry.compiled:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments
entry.runnable = self.vllm_backend.compiler_manager.compile(
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify.
args = (
self._fakify_args(args)
if not range_entry.compile_range.is_single_size()
else args
)
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape,
)
# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation()
self.check_for_ending_compilation()
def _find_range_for_shape(self, runtime_shape: int) -> Range | None:
# First we try to find the range entry for the concrete compile size
# If not found, we search for the range entry
# that contains the runtime shape.
if runtime_shape in self.compile_sizes:
return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
else:
for range in self.compile_ranges:
if runtime_shape in range:
return self.range_entries[range]
return None
def __call__(self, *args) -> Any:
runtime_shape = args[self.sym_shape_indices[0]]
range_entry = self._find_range_for_shape(runtime_shape)
assert range_entry is not None, (
f"Shape out of considered range: {runtime_shape} "
"[1, max_num_batched_tokens]"
)
return entry.runnable(*args)
self._maybe_compile_for_range_entry(range_entry, args)
return range_entry.runnable(*args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fusion import empty_bf16
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherSiluAndMul
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
# Module-level logger used by the fusion passes below.
logger = init_logger(__name__)
# FP8 dtype reported by the current platform; handed to the RMSNorm pattern
# classes when they are constructed.
FP8_DTYPE = current_platform.fp8_dtype()

# Fused targets emitted by the RMSNorm replacements (without / with residual add).
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = (
    torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
)

# Unfused RMSNorm ops that the patterns search for (without / with residual add).
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default

# Group FP8 quant ops; every pattern is registered once per quant op.
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default

# Fused target emitted by the silu_and_mul + group FP8 quant replacement.
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
class AiterRMSFp8GroupQuantPattern:
"""
This pattern fuses aiter rms_norm & group fp8 quant custom
ops into an aiter rms_norm_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_GROUP_QUANT_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
return at[0], at[1]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class AiterFusedAddRMSFp8GroupQuantPattern:
"""
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
into a aiter rms_norm_with_add_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_ADD_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
)
at2 = self.quant_op(at1[0], 128)
# result, scale, residual
return at2[0], at2[1], at1[1]
def replacement(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_ADD_GROUP_QUANT_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
# result, scale, residual
return at[0], at[1], at[2]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Create the pattern matcher pass and register every fusion pattern.

        Patterns are registered for each combination of epsilon
        (1e-5, 1e-6) and group FP8 quant op (aiter, triton).
        """
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        # NOTE(review): the loop below registers the simple rms_norm pattern
        # first and the fused-add pattern second, which does not match the
        # ordering described above -- confirm which one is intended.
        for epsilon in [1e-5, 1e-6]:
            # Fuse rms_norm + dynamic group fp8 quant
            for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
                AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
                    self.patterns
                )

                AiterFusedAddRMSFp8GroupQuantPattern(
                    epsilon, FP8_DTYPE, quant_op
                ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to `graph` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Return `hash_source` computed over this pass and its pattern classes."""
        fusion_patterns = [
            AiterRMSFp8GroupQuantPattern,
            AiterFusedAddRMSFp8GroupQuantPattern,
        ]
        return self.hash_source(self, *fusion_patterns)
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
    """
    Rewrite rule: an aiter silu_and_mul followed by a group fp8 quant custom
    op is replaced with a single aiter silu_and_mul_group_fp8_quant op.
    """

    def __init__(self, quant_op: OpOverload):
        self.silu_and_mul_matcher = MatcherSiluAndMul()
        self.quant_op = quant_op

    def register(self, pm_pass: PatternMatcherPass):
        """Register the unfused -> fused rewrite on `pm_pass`."""

        def pattern(
            input: torch.Tensor,
        ):
            activated = self.silu_and_mul_matcher(input)
            quantized = self.quant_op(activated, 128)
            return quantized[0], quantized[1]

        def replacement(
            input: torch.Tensor,
        ):
            fused = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
            return fused[0], fused[1]

        # Only the matcher's first example input is used.
        example_inputs = [self.silu_and_mul_matcher.inputs()[0]]
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    Fuses silu_and_mul followed by a group fp8 quant custom op into one fused
    op, using the torch pattern matcher to find and replace the patterns.

    NOTE(review): this pass has been described as a singleton because
    patterns can only be registered once, to be addressed in a future
    version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    Nothing in this class enforces that by itself -- confirm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
        )

        # One pattern per supported group fp8 quant op.
        for group_quant_op in (AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP):
            silu_mul_pattern = AiterSiluMulFp8GroupQuantPattern(group_quant_op)
            silu_mul_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph):
        """Apply the registered patterns to `graph` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self):
        """Return `hash_source` computed over this pass and its pattern classes."""
        return VllmInductorPass.hash_source(
            self,
            ActivationQuantPattern,
            AiterSiluMulFp8GroupQuantPattern,
        )
......@@ -9,6 +9,7 @@ import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
......@@ -333,7 +334,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
def is_applicable(self, shape: int | None) -> bool:
def is_applicable_for_range(self, compile_range: Range) -> bool:
# When sequence parallelism is enabled, the residual tensor from RMSNorm
# needs to be split along the sequence dimension. However, this dimension
# is symbolic during piecewise compilation, and splitting symbolic shapes
......@@ -353,7 +354,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
):
return True
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
......
......@@ -4,7 +4,7 @@
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from types import CodeType
from typing import Any
......@@ -13,7 +13,9 @@ import torch._C._dynamo.guards
import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
logger = init_logger(__name__)
......@@ -92,12 +94,29 @@ class TorchCompileWithNoGuardsWrapper:
return self.forward(*args, **kwargs)
def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs):
if self.layerwise_nvtx_tracing_enabled:
args_list = list(args)
kwargs_dict = dict(kwargs)
with layerwise_nvtx_marker_context(
"Torch Compiled Module (input):{}".format(self.__class__.__name__),
self,
in_tensor=args_list,
kwargs=kwargs_dict,
) as ctx:
ctx.result = callable_fn(*args, **kwargs)
return ctx.result
return callable_fn(*args, **kwargs)
def __init__(self):
self.compiled = False
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
mode = vllm_config.compilation_config.mode
self.layerwise_nvtx_tracing_enabled = (
vllm_config.observability_config.enable_layerwise_nvtx_tracing
)
if mode is None:
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
......@@ -107,23 +126,49 @@ class TorchCompileWithNoGuardsWrapper:
if isinstance(backend, str) and backend == "inductor":
options = vllm_config.compilation_config.inductor_compile_config
self.first_compile = True
self.evaluate_guards = (
vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
)
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
if mode != CompilationMode.STOCK_TORCH_COMPILE:
# Drop all the guards.
options["guard_filter_fn"] = lambda x: [False for _ in x]
if self.evaluate_guards:
assert not envs.VLLM_USE_BYTECODE_HOOK, (
"compilation_config.dynamic_shapes_config.evaluate_guards "
"requires VLLM_USE_BYTECODE_HOOK=0. "
)
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
from vllm.compilation.decorators import DynamicShapesType
if envs.VLLM_USE_AOT_COMPILE:
# disabled until https://github.com/pytorch/pytorch/pull/169239
# is picked up.
assert ds_type != DynamicShapesType.BACKED, (
"evaluate_guards for backed shapes requires "
"VLLM_USE_AOT_COMPILE=False. "
)
options["guard_filter_fn"] = lambda x: [
entry.guard_type == "SHAPE_ENV" for entry in x
]
else:
options["guard_filter_fn"] = lambda x: [False for _ in x]
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
compiled_ptr: Any = self.forward
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
if ds_type == DynamicShapesType.UNBACKED:
if envs.VLLM_USE_BYTECODE_HOOK:
# reason is that bytecode does this hack torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation.
raise ValueError(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
)
# reason is that bytecode does torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation. And if we use
# compiled_ptr = self.check_invariants_and_forward
# it will reset all entries.
assert not envs.VLLM_USE_BYTECODE_HOOK, (
"UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
)
assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"
compiled_ptr = self.check_invariants_and_forward
if envs.VLLM_USE_AOT_COMPILE:
......@@ -168,13 +213,25 @@ class TorchCompileWithNoGuardsWrapper:
# Make sure a compilation is triggered by clearing dynamo
# cache.
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
return self._compiled_callable(*args, **kwargs)
return self._call_with_optional_nvtx_range(
self._compiled_callable, *args, **kwargs
)
else:
with self._dispatch_to_compiled_code():
return self.forward(*args, **kwargs)
return self._call_with_optional_nvtx_range(
self.forward, *args, **kwargs
)
else:
with _compilation_context():
return self._compiled_callable(*args, **kwargs)
ctx = (
nullcontext()
if self.first_compile or not self.evaluate_guards
else torch.compiler.set_stance("fail_on_recompile")
)
self.first_compile = False
with _compilation_context(), ctx:
return self._call_with_optional_nvtx_range(
self._compiled_callable, *args, **kwargs
)
@abstractmethod
def forward(self, *args, **kwargs): ...
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config.attention import AttentionConfig
from vllm.config.cache import CacheConfig
from vllm.config.compilation import (
CompilationConfig,
......@@ -23,6 +24,7 @@ from vllm.config.multimodal import MultiModalConfig
from vllm.config.observability import ObservabilityConfig
from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.profiler import ProfilerConfig
from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig
......@@ -46,6 +48,8 @@ from vllm.config.vllm import (
# __all__ should only contain classes and functions.
# Types and globals should be imported from their respective modules.
__all__ = [
# From vllm.config.attention
"AttentionConfig",
# From vllm.config.cache
"CacheConfig",
# From vllm.config.compilation
......@@ -86,6 +90,8 @@ __all__ = [
"SpeechToTextConfig",
# From vllm.config.structured_outputs
"StructuredOutputsConfig",
# From vllm.config.profiler
"ProfilerConfig",
# From vllm.config.utils
"ConfigType",
"SupportsMetricsInfo",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config
from vllm.logger import init_logger
logger = init_logger(__name__)
@config
@dataclass
class AttentionConfig:
    """Configuration for attention mechanisms in vLLM."""

    backend: AttentionBackendEnum | None = None
    """Attention backend to use. If None, will be selected automatically."""

    flash_attn_version: Literal[2, 3] | None = None
    """Force vllm to use a specific flash-attention version (2 or 3).
    Only valid when using the flash-attention backend."""

    use_prefill_decode_attention: bool = False
    """Use separate prefill and decode kernels for attention instead of
    the unified triton kernel."""

    flash_attn_max_num_splits_for_cuda_graph: int = 32
    """Flash Attention max number splits for cuda graph decode."""

    use_cudnn_prefill: bool = False
    """Whether to use cudnn prefill."""

    use_trtllm_ragged_deepseek_prefill: bool = False
    """Whether to use TRTLLM ragged deepseek prefill."""

    use_trtllm_attention: bool | None = None
    """If set to True/False, use or don't use the TRTLLM attention backend
    in flashinfer. If None, auto-detect the attention backend in flashinfer."""

    disable_flashinfer_prefill: bool = False
    """Whether to disable flashinfer prefill."""

    disable_flashinfer_q_quantization: bool = False
    """If set, when using fp8 kv, do not quantize Q to fp8."""

    def compute_hash(self) -> str:
        """
        Return a hash identifying every setting in this config that affects
        the structure of the computation graph from input ids/embeddings to
        the final hidden states (anything before the input ids/embeddings or
        after the final hidden states is out of scope).
        """
        from vllm.config.utils import get_hash_factors, hash_factors

        # No field is excluded from the hash.
        excluded: list[str] = []
        return hash_factors(get_hash_factors(self, excluded))

    @field_validator("backend", mode="before")
    @classmethod
    def validate_backend_before(cls, value: Any) -> Any:
        """Enable parsing of the `backend` enum type from string."""
        return AttentionBackendEnum[value.upper()] if isinstance(value, str) else value

    def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
        """Override `field_name` from `env_var_name` when that env var is set.

        Emits a one-time deprecation warning whenever an override happens.
        """
        from vllm import envs

        if not envs.is_set(env_var_name):
            return

        env_value = getattr(envs, env_var_name)
        if field_name == "backend":
            # The backend arrives as a string and must be parsed to the enum.
            env_value = self.validate_backend_before(env_value)
        setattr(self, field_name, env_value)
        logger.warning_once(
            "Using %s environment variable is deprecated and will be removed in "
            "v0.14.0 or v1.0.0, whichever is soonest. Please use "
            "--attention-config.%s command line argument or "
            "AttentionConfig(%s=...) config field instead.",
            env_var_name,
            field_name,
            field_name,
        )

    def __post_init__(self) -> None:
        """Apply the deprecated environment-variable overrides, field by field."""
        env_overrides = (
            ("backend", "VLLM_ATTENTION_BACKEND"),
            ("flash_attn_version", "VLLM_FLASH_ATTN_VERSION"),
            ("use_prefill_decode_attention", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION"),
            (
                "flash_attn_max_num_splits_for_cuda_graph",
                "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
            ),
            ("use_cudnn_prefill", "VLLM_USE_CUDNN_PREFILL"),
            (
                "use_trtllm_ragged_deepseek_prefill",
                "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
            ),
            ("use_trtllm_attention", "VLLM_USE_TRTLLM_ATTENTION"),
            ("disable_flashinfer_prefill", "VLLM_DISABLE_FLASHINFER_PREFILL"),
            (
                "disable_flashinfer_q_quantization",
                "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
            ),
        )
        for field_name, env_var_name in env_overrides:
            self._set_from_env_if_set(field_name, env_var_name)
......@@ -29,8 +29,8 @@ CacheDType = Literal[
"fp8_inc",
"fp8_ds_mla",
]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
MambaDType = Literal["auto", "float32", "float16"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
......@@ -77,9 +77,21 @@ class CacheConfig:
"""Whether to enable prefix caching."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n
- "sha256" uses Pickle for object serialization before hashing.\n
- "sha256" uses Pickle for object serialization before hashing. This is the
current default, as SHA256 is the most secure choice to avoid potential
hash collisions.\n
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
serializes objects using canonical CBOR and hashes them with SHA-256."""
serializes objects using canonical CBOR and hashes them with SHA-256.\n
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
non-cryptographic hashing. Requires the optional ``xxhash`` package.
IMPORTANT: Use of a hashing algorithm that is not considered
cryptographically secure theoretically increases the risk of hash collisions,
which can cause undefined behavior or even leak private information in
multi-tenant environments. Even if collisions are still very unlikely, it is
important to consider your security risk tolerance against the performance
benefits before turning this on.\n
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
reproducible hashing. Requires the optional ``xxhash`` package."""
cpu_offload_gb: float = Field(default=0, ge=0)
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
......
......@@ -4,7 +4,7 @@
import enum
from collections import Counter
from collections.abc import Callable
from dataclasses import asdict, field
from dataclasses import field
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal
......@@ -13,7 +13,13 @@ from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config, handle_deprecated
from vllm.config.utils import (
Range,
config,
get_hash_factors,
handle_deprecated,
hash_factors,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
......@@ -173,6 +179,9 @@ class PassConfig:
"""
MiB = 1024 * 1024
FI_SUPPORTED_WORLD_SIZES = [2, 4, 8]
if world_size not in FI_SUPPORTED_WORLD_SIZES:
return None
max_size_mb = self.fi_allreduce_fusion_max_size_mb
if max_size_mb is None:
max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size)
......@@ -196,7 +205,16 @@ class PassConfig:
Any new fields that affect compilation should be added to the hash.
Any future fields that don't affect compilation should be excluded.
"""
return InductorPass.hash_dict(asdict(self))
ignored_fields = [
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
]
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
@field_validator(
"fuse_norm_quant",
......@@ -267,14 +285,6 @@ class PassConfig:
"v0.13.0 or v1.0.0, whichever is sooner",
)
# Force old flags to None to ensure they are not used
self.enable_fusion = None
self.enable_attn_fusion = None
self.enable_noop = None
self.enable_sequence_parallelism = None
self.enable_async_tp = None
self.enable_fi_allreduce_fusion = None
if not self.eliminate_noops:
if self.fuse_norm_quant or self.fuse_act_quant:
logger.warning_once(
......@@ -334,7 +344,18 @@ class DynamicShapesConfig:
backed/unbacked.
"""
# TODO add a debug mode to fail
evaluate_guards: bool = False
"""
A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by
guarding on it. When True, dynamic shape guards are not dropped from dynamo.
And a failure will be triggered if a recompilation ever happens due to that.
This mode requires VLLM_USE_BYTECODE_HOOK to be 0.
Enabling this allow observing the dynamic shapes guards in the tlparse
artifacts also.
When type is backed, aot_compile must be disabled for this mode to work.
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
def compute_hash(self) -> str:
"""
......@@ -378,6 +399,8 @@ class CompilationConfig:
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- Inductor compilation:
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`compile_ranges_split_points`]
[vllm.config.CompilationConfig.compile_ranges_split_points]
- [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config]
- [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
......@@ -443,8 +466,8 @@ class CompilationConfig:
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation mode is 3, the backend supports both whole graph and piecewise
compilation, available backends include eager, inductor, and custom backends,
compilation mode is 3, the backend supports both whole graph and piecewise
compilation, available backends include eager, inductor, and custom backends,
the latter of which can be defined via `get_compile_backend`. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for
......@@ -491,6 +514,21 @@ class CompilationConfig:
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""
compile_ranges_split_points: list[int] | None = None
"""Split points that represent compile ranges for inductor.
The compile ranges are
[1, split_points[0]],
[split_points[0] + 1, split_points[1]], ...,
[split_points[-1] + 1, max_num_batched_tokens].
Compile sizes are also used single element ranges,
the range is represented as [compile_sizes[i], compile_sizes[i]].
If a range overlaps with the compile size, graph for compile size
will be prioritized, i.e. if we have a range [1, 8] and a compile size 4,
graph for compile size 4 will be compiled and used instead of the graph
for range [1, 8].
"""
inductor_compile_config: dict = field(default_factory=dict)
"""Additional configurations for inductor.
- None: use default configurations."""
......@@ -939,7 +977,9 @@ class CompilationConfig:
# May get recomputed in the model runner if adjustment is needed for spec-decode
self.compute_bs_to_padded_graph_size()
def set_splitting_ops_for_v1(self):
def set_splitting_ops_for_v1(
self, all2all_backend: str | None = None, data_parallel_size: int | None = None
):
# For compatibility with OOT hardware plugin platforms (for example vllm-ascend)
# which currently only support sequence parallelism in eager mode.
if self.mode != CompilationMode.VLLM_COMPILE:
......@@ -954,50 +994,83 @@ class CompilationConfig:
"mode is CompilationMode.VLLM_COMPILE"
)
if self.use_inductor_graph_partition:
self.set_splitting_ops_for_inductor_graph_partition()
return
added_default_splitting_ops = False
if self.pass_config.fuse_attn_quant:
# here use_inductor_graph_partition is False
if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
self.set_splitting_ops_for_attn_fusion()
return
if self.splitting_ops is None:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture
# the full cudagraph outside the fx graph. This reduces some
# cpu overhead when the runtime batch_size is not cudagraph
# captured. see https://github.com/vllm-project/vllm/pull/20059
# for details. Make a copy to avoid mutating the class-level
# list via reference.
self.splitting_ops = list(self._attention_ops)
elif len(self.splitting_ops) == 0:
logger.warning_once("Using piecewise compilation with empty splitting_ops")
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
else:
if self.splitting_ops is None:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture
# the full cudagraph outside the fx graph. This reduces some
# cpu overhead when the runtime batch_size is not cudagraph
# captured. see https://github.com/vllm-project/vllm/pull/20059
# for details. Make a copy to avoid mutating the class-level
# list via reference.
self.splitting_ops = list(self._attention_ops)
added_default_splitting_ops = True
elif len(self.splitting_ops) == 0:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not"
"contains piecewise cudagraph. Setting cudagraph_"
"mode to NONE. Hint: If you are using attention backends "
"that support cudagraph, consider manually setting "
"cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
"full cudagraphs."
"Using piecewise compilation with empty splitting_ops"
)
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not"
"contains piecewise cudagraph. Setting cudagraph_"
"mode to NONE. Hint: If you are using attention "
"backends that support cudagraph, consider manually "
"setting cudagraph_mode to FULL or FULL_DECODE_ONLY "
"to enable full cudagraphs."
)
self.cudagraph_mode = CUDAGraphMode.NONE
elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do "
"not contains piecewise cudagraph. Setting "
"cudagraph_mode to FULL."
)
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
# split MoE ops for cudagraph
moe_ops = [
"vllm::moe_forward",
"vllm::moe_forward_shared",
]
backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
dp_size = data_parallel_size if data_parallel_size is not None else 1
need_moe_splitting = (
backend == "deepep_high_throughput"
and dp_size > 1
# pure attn-fusion without inductor partition deliberately disables
# piecewise graphs and MoE splitting.
and not (
self.pass_config.fuse_attn_quant
and not self.use_inductor_graph_partition
)
)
if need_moe_splitting and self.cudagraph_mode != CUDAGraphMode.NONE:
# if we just initialized default splitting_ops for this config,
# automatically append the MoE ops
if added_default_splitting_ops:
for op in moe_ops:
if op not in self.splitting_ops:
self.splitting_ops.append(op)
# make sure MoE ops are split out
if not any(op in self.splitting_ops for op in moe_ops):
self.cudagraph_mode = CUDAGraphMode.NONE
elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not "
"contains piecewise cudagraph. Setting cudagraph_mode "
"to FULL."
"DeepEP high throughput backend with data_parallel_size > 1 "
"requires splitting MoE ops from cudagraphs. Please ensure "
"'vllm::moe_forward' or 'vllm::moe_forward_shared' are "
"present in CompilationConfig.splitting_ops."
)
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
def set_splitting_ops_for_inductor_graph_partition(self):
assert self.use_inductor_graph_partition
if self.splitting_ops is None:
self.splitting_ops = list(self._attention_ops)
elif self.cudagraph_mode.has_full_cudagraphs():
# fall back to piecewise when MoE splitting is required.
self.cudagraph_mode = CUDAGraphMode.PIECEWISE
def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.fuse_attn_quant
......@@ -1152,3 +1225,13 @@ class CompilationConfig:
self.bs_to_padded_graph_size[bs] = start
else:
self.bs_to_padded_graph_size[bs] = end
def get_compile_ranges(self) -> list[Range]:
    """Build the list of compile ranges from the configured split points.

    Returns an empty list when `compile_ranges_split_points` is None.
    Otherwise the split points are de-duplicated and sorted, and each one
    closes a range that starts right after the previous split point (the
    first range starts at 1).
    """
    raw_points = self.compile_ranges_split_points
    if raw_points is None:
        return []

    compile_ranges = []
    lower_bound = 1
    for upper_bound in sorted(set(raw_points)):
        compile_ranges.append(Range(start=lower_bound, end=upper_bound))
        # The next range begins immediately after this split point.
        lower_bound = upper_bound + 1
    return compile_ranges
......@@ -4,7 +4,7 @@
import warnings
from collections.abc import Callable
from dataclasses import InitVar, field
from importlib.util import find_spec
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch
......@@ -37,15 +37,13 @@ from vllm.transformers_utils.config import (
uses_xdrope_dim,
)
from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf,
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import (
is_gguf,
is_remote_gguf,
maybe_model_redirect,
maybe_patch_hf_config_from_gguf,
split_remote_gguf,
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype
......@@ -86,7 +84,7 @@ TaskOption = Literal[
"transcription",
"draft",
]
TokenizerMode = Literal["auto", "hf", "slow", "mistral"]
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[
"raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
......@@ -139,10 +137,12 @@ class ModelConfig:
name or path will be used."""
tokenizer_mode: TokenizerMode | str = "auto"
"""Tokenizer mode:\n
- "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.\n
- "auto" will use the tokenizer from `mistral_common` for Mistral models
if available, otherwise it will use the "hf" tokenizer.\n
- "hf" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- Other custom values can be supported via plugins."""
trust_remote_code: bool = False
"""Trust remote code (e.g., from HuggingFace) when downloading the model
......@@ -471,18 +471,6 @@ class ModelConfig:
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
if (
(backend := envs.VLLM_ATTENTION_BACKEND)
and backend == "FLASHINFER"
and find_spec("flashinfer") is None
):
raise ValueError(
"VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer "
"module was not found. See "
"https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501
"for instructions on how to install it."
)
from vllm.platforms import current_platform
if self.override_attention_dtype is not None and not current_platform.is_rocm():
......@@ -531,7 +519,11 @@ class ModelConfig:
if task == "classify":
return "classify"
if task == "reward":
return "reward"
logger.warning(
"Pooling models now default support all pooling; "
"you can use it without any settings."
)
return "embed"
if task == "score":
new_task = self._get_default_pooling_task(architectures)
return "classify" if new_task == "classify" else "embed"
......@@ -1233,6 +1225,19 @@ class ModelConfig:
)
return False
@cached_property
def is_mm_prefix_lm(self) -> bool:
    """Whether to use bidirectional attention for mm positions.

    True only when `hf_config` has a `model_type` attribute whose value is
    one of the prefix-LM model types listed below.
    """
    prefix_lm_model_types = (
        "gemma3",
        # TODO(Isotr0py): Disable paligemma for now before
        # we supports soft cap attention for FlexAttention
        # "paligemma",
    )
    if hasattr(self.hf_config, "model_type"):
        return self.hf_config.model_type in prefix_lm_model_types
    # Configs without a model_type never qualify.
    return False
def get_head_size(self) -> int:
# TODO remove hard code
if self.is_deepseek_mla:
......@@ -1784,20 +1789,22 @@ class ModelConfig:
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
if pooling_type in ["mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support chunked prefill.",
pooling_type,
)
return False
else:
# pooling_type == "last"
elif pooling_type in ["all", "last"]:
logger.debug(
"Pooling models with causal attn and last pooling support "
"chunked prefill."
"Pooling models with causal attn and %s pooling support "
"chunked prefill.",
pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder"
......@@ -1821,20 +1828,22 @@ class ModelConfig:
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
if pooling_type in ["mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support prefix caching.",
pooling_type,
)
return False
else:
# pooling_type == "last"
elif pooling_type in ["all", "last"]:
logger.debug(
"Pooling models with causal attn and last pooling support "
"prefix caching."
"Pooling models with causal attn and %s pooling support "
"prefix caching.",
pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return False
......@@ -1897,8 +1906,8 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
("ForImageClassification", ("pooling", "classify")),
("ForVideoClassification", ("pooling", "classify")),
("ClassificationModel", ("pooling", "classify")),
("ForRewardModeling", ("pooling", "reward")),
("RewardModel", ("pooling", "reward")),
("ForRewardModeling", ("pooling", "embed")),
("RewardModel", ("pooling", "embed")),
# Let other `*Model`s take priority
("Model", ("pooling", "embed")),
]
......
......@@ -55,6 +55,15 @@ class ObservabilityConfig:
kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
cudagraph_metrics: bool = False
"""Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
dispatch modes, and their observed frequencies at every logging interval)."""
enable_layerwise_nvtx_tracing: bool = False
"""Enable layerwise NVTX tracing. This traces the execution of each layer or
module in the model and attaches information such as input/output shapes to
nvtx range markers. Note that this doesn't work with CUDA graphs enabled."""
@cached_property
def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request."""
......
......@@ -35,6 +35,7 @@ logger = init_logger(__name__)
ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"]
@config
......@@ -65,6 +66,9 @@ class EPLBConfig:
Whether to use non-blocking EPLB.
"""
policy: EPLBPolicyOption = "default"
"""The policy type for expert parallel load balancing (EPLB)."""
@config
@dataclass
......@@ -180,13 +184,14 @@ class ParallelConfig:
distributed_executor_backend: (
str | DistributedExecutorBackend | type[Executor] | None
) = None
"""Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
of pipeline_parallel_size and tensor_parallel_size is less than
or equal to the number of GPUs available, "mp" will be used to
keep processing on a single host. Otherwise, this will default
to "ray" if Ray is installed and fail otherwise. Note that tpu
only support Ray for distributed inference."""
"""Backend to use for distributed model workers, either "ray" or "mp"
(multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size
is less than or equal to the number of GPUs available, "mp" will be used to
keep processing on a single host. Otherwise, an error will be raised. To use "mp"
you must also set nnodes, and to use "ray" you must manually set
distributed_executor_backend to "ray".
Note that tpu only support Ray for distributed inference."""
worker_cls: str = "auto"
"""The full name of the worker class to use. If "auto", the worker class
......@@ -562,8 +567,11 @@ class ParallelConfig:
):
gpu_count = cuda_device_count_stateless()
raise ValueError(
f"Tensor parallel size ({self.world_size}) cannot be "
f"larger than the number of available GPUs ({gpu_count})."
f"World size ({self.world_size}) is larger than the number of "
f"available GPUs ({gpu_count}) in this node. If this is "
"intentional and you are using:\n"
"- ray, set '--distributed-executor-backend ray'.\n"
"- multiprocessing, set '--nnodes' appropriately."
)
elif self.data_parallel_backend == "ray":
logger.info(
......@@ -593,10 +601,14 @@ class ParallelConfig:
"max_parallel_loading_workers is currently "
"not supported and will be ignored."
)
if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1:
allowed_backends = ("mp", "uni", "external_launcher")
if (
self.distributed_executor_backend not in allowed_backends
and self.nnodes > 1
):
raise ValueError(
"nnodes > 1 can only be set when distributed executor "
"backend is mp or uni."
"backend is mp, uni or external_launcher."
)
@property
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any, Literal
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
ProfilerKind = Literal["torch", "cuda"]
@config
@dataclass
class ProfilerConfig:
    """Dataclass which contains profiler config for the engine."""

    # NOTE: the post-init validator below lets the deprecated VLLM_* profiler
    # environment variables override the corresponding fields when they are set.

    profiler: ProfilerKind | None = None
    """Which profiler to use. Defaults to None. Options are:
    - 'torch': Use PyTorch profiler.\n
    - 'cuda': Use CUDA profiler."""

    torch_profiler_dir: str = ""
    """Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
    worker's traces (CPU & GPU) will be saved under this directory. Note that
    it must be an absolute path."""

    torch_profiler_with_stack: bool = True
    """If `True`, enables stack tracing in the torch profiler. Enabled by default."""

    torch_profiler_with_flops: bool = False
    """If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""

    torch_profiler_use_gzip: bool = True
    """If `True`, saves torch profiler traces in gzip format. Enabled by default"""

    torch_profiler_dump_cuda_time_total: bool = True
    """If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""

    torch_profiler_record_shapes: bool = False
    """If `True`, records tensor shapes in the torch profiler. Disabled by default."""

    torch_profiler_with_memory: bool = False
    """If `True`, enables memory profiling in the torch profiler.
    Disabled by default."""

    ignore_frontend: bool = False
    """If `True`, disables the front-end profiling of AsyncLLM when using the
    'torch' profiler. This is needed to reduce overhead when using delay/limit options,
    since the front-end profiling does not track iterations and will capture the
    entire range.
    """

    delay_iterations: int = Field(default=0, ge=0)
    """Number of engine iterations to skip before starting profiling.
    Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
    """

    max_iterations: int = Field(default=0, ge=0)
    """Maximum number of engine iterations to profile after starting profiling.
    Defaults to 0, meaning no limit.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> Any:
        """Get field from env var if set, with deprecation warning.

        Args:
            field_name: Name of the config field that replaces the env var;
                only used in the deprecation message.
            env_var_name: Name of the attribute to look up on `vllm.envs`.

        Returns:
            The value read from `vllm.envs` when the variable is set,
            otherwise None.
        """
        if envs.is_set(env_var_name):
            value = getattr(envs, env_var_name)
            logger.warning_once(
                "Using %s environment variable is deprecated and will be removed in "
                "v0.14.0 or v1.0.0, whichever is soonest. Please use "
                "--profiler-config.%s command line argument or "
                "ProfilerConfig(%s=...) config field instead.",
                env_var_name,
                field_name,
                field_name,
            )
            return value
        return None

    def _set_from_env_if_set(
        self,
        field_name: str,
        env_var_name: str,
        to_bool: bool = True,
        to_int: bool = False,
    ) -> None:
        """Set field from env var if set, with deprecation warning.

        Args:
            field_name: Attribute on this config to overwrite.
            env_var_name: Deprecated environment variable to read.
            to_bool: When True, store whether the value equals "1".
            to_int: When True, pass the value through `int()`; this runs
                after the optional bool conversion.
        """
        value = self._get_from_env_if_set(field_name, env_var_name)
        if value is not None:
            if to_bool:
                value = value == "1"
            if to_int:
                value = int(value)
            setattr(self, field_name, value)

    @model_validator(mode="after")
    def _validate_profiler_config(self) -> Self:
        """Apply deprecated env var overrides, then validate the result.

        Raises:
            ValueError: If `torch_profiler_dir` is set while the profiler is
                not 'torch', or the profiler is 'torch' without a directory.
        """
        maybe_use_cuda_profiler = self._get_from_env_if_set(
            "profiler", "VLLM_TORCH_CUDA_PROFILE"
        )
        if maybe_use_cuda_profiler is not None:
            # The CUDA profiler env var takes precedence over the torch
            # profiler directory env var.
            self.profiler = "cuda" if maybe_use_cuda_profiler == "1" else None
        else:
            self._set_from_env_if_set(
                "torch_profiler_dir", "VLLM_TORCH_PROFILER_DIR", to_bool=False
            )
            if self.torch_profiler_dir:
                self.profiler = "torch"
        self._set_from_env_if_set(
            "torch_profiler_record_shapes",
            "VLLM_TORCH_PROFILER_RECORD_SHAPES",
        )
        self._set_from_env_if_set(
            "torch_profiler_with_memory",
            "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY",
        )
        self._set_from_env_if_set(
            "torch_profiler_with_stack",
            "VLLM_TORCH_PROFILER_WITH_STACK",
        )
        self._set_from_env_if_set(
            "torch_profiler_with_flops",
            "VLLM_TORCH_PROFILER_WITH_FLOPS",
        )
        self._set_from_env_if_set(
            "ignore_frontend",
            "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM",
        )
        self._set_from_env_if_set(
            "torch_profiler_use_gzip",
            "VLLM_TORCH_PROFILER_USE_GZIP",
        )
        self._set_from_env_if_set(
            "torch_profiler_dump_cuda_time_total",
            "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL",
        )

        self._set_from_env_if_set(
            "delay_iterations", "VLLM_PROFILER_DELAY_ITERS", to_bool=False, to_int=True
        )
        self._set_from_env_if_set(
            "max_iterations", "VLLM_PROFILER_MAX_ITERS", to_bool=False, to_int=True
        )

        has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0
        if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend:
            logger.warning_once(
                "Using 'torch' profiler with delay_iterations or max_iterations "
                "while ignore_frontend is False may result in high overhead."
            )

        profiler_dir = self.torch_profiler_dir
        if profiler_dir and self.profiler != "torch":
            raise ValueError(
                "torch_profiler_dir is only applicable when profiler is set to 'torch'"
            )
        if self.profiler == "torch" and not profiler_dir:
            raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")

        if profiler_dir:
            # A "gs://" prefix followed by a non-empty, non-"/" remainder is
            # kept verbatim; any other value is expanded to an absolute path.
            is_gs_path = (
                profiler_dir.startswith("gs://")
                and profiler_dir[5:]
                and profiler_dir[5] != "/"
            )
            if not is_gs_path:
                self.torch_profiler_dir = os.path.abspath(
                    os.path.expanduser(profiler_dir)
                )

        return self
......@@ -337,6 +337,7 @@ class SpeculativeConfig:
enforce_eager=self.target_model_config.enforce_eager,
max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
config_format=self.target_model_config.config_format,
)
# Automatically detect the method
......
......@@ -65,22 +65,6 @@ class StructuredOutputsConfig:
@model_validator(mode="after")
def _validate_structured_output_config(self) -> Self:
# Import here to avoid circular import
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
if self.reasoning_parser_plugin and len(self.reasoning_parser_plugin) > 3:
ReasoningParserManager.import_reasoning_parser(self.reasoning_parser_plugin)
valid_reasoning_parsers = ReasoningParserManager.list_registered()
if (
self.reasoning_parser != ""
and self.reasoning_parser not in valid_reasoning_parsers
):
raise ValueError(
f"invalid reasoning parser: {self.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
)
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
raise ValueError(
"disable_any_whitespace is only supported for "
......
......@@ -10,7 +10,7 @@ import json
import pathlib
import textwrap
from collections.abc import Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
......@@ -322,3 +322,35 @@ def handle_deprecated(
for new_name in new_names:
setattr(config, new_name, old_val)
@dataclass
class Range:
    """
    A closed interval of numbers.

    Both `start` and `end` belong to the range.
    """

    start: int
    end: int

    def is_single_size(self) -> bool:
        """Return True when the range holds exactly one value."""
        return self.end == self.start

    def __contains__(self, size: int) -> bool:
        # Both bounds are inclusive.
        at_or_above_start = self.start <= size
        return at_or_above_start and size <= self.end

    def __eq__(self, other: object) -> bool:
        # Only another Range with the same endpoints compares equal.
        if isinstance(other, Range):
            return (self.start, self.end) == (other.start, other.end)
        return False

    def __hash__(self) -> int:
        # Hash on the same endpoint pair that __eq__ compares.
        endpoints = (self.start, self.end)
        return hash(endpoints)

    def __str__(self) -> str:
        return "({}, {})".format(self.start, self.end)

    def __repr__(self) -> str:
        # The repr deliberately mirrors the string form.
        return str(self)
......@@ -27,6 +27,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
from vllm.utils.hashing import safe_hash
from .attention import AttentionConfig
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
......@@ -38,6 +39,7 @@ from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
......@@ -65,7 +67,7 @@ class OptimizationLevel(IntEnum):
"""O0 : No optimization. no compilation, no cudagraphs, no other
optimization, just starting up immediately"""
O1 = 1
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
cudagraphs"""
O2 = 2
"""O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
......@@ -192,6 +194,8 @@ class VllmConfig:
"""Device configuration."""
load_config: LoadConfig = Field(default_factory=LoadConfig)
"""Load configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
lora_config: LoRAConfig | None = None
"""LoRA configuration."""
speculative_config: SpeculativeConfig | None = None
......@@ -215,6 +219,8 @@ class VllmConfig:
You can specify the full compilation config like so:
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
"""
profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
"""Profiling configuration."""
kv_transfer_config: KVTransferConfig | None = None
"""The configurations for distributed KV cache transfer."""
kv_events_config: KVEventsConfig | None = None
......@@ -279,6 +285,10 @@ class VllmConfig:
vllm_factors.append(self.load_config.compute_hash())
else:
vllm_factors.append("None")
if self.attention_config:
vllm_factors.append(self.attention_config.compute_hash())
else:
vllm_factors.append("None")
if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash())
else:
......@@ -289,6 +299,8 @@ class VllmConfig:
vllm_factors.append("None")
if self.structured_outputs_config:
vllm_factors.append(self.structured_outputs_config.compute_hash())
if self.profiler_config:
vllm_factors.append(self.profiler_config.compute_hash())
else:
vllm_factors.append("None")
vllm_factors.append(self.observability_config.compute_hash())
......@@ -579,6 +591,15 @@ class VllmConfig:
else:
self.scheduler_config.async_scheduling = True
if (
self.scheduler_config.async_scheduling
and not self.parallel_config.disable_nccl_for_dp_synchronization
):
logger.info(
"Disabling NCCL for DP synchronization when using async scheduling."
)
self.parallel_config.disable_nccl_for_dp_synchronization = True
from vllm.platforms import current_platform
if (
......@@ -671,36 +692,22 @@ class VllmConfig:
if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
if self.parallel_config.decode_context_parallel_size > 1:
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and self.model_config is not None
):
if self.model_config.pooler_config is not None:
logger.warning_once(
"Decode context parallel (DCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# prefill context parallel do not support full cudagraphs
elif self.parallel_config.prefill_context_parallel_size > 1:
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Prefill context parallel (PCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config is not None:
if self.model_config.pooler_config is not None:
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
......@@ -732,6 +739,8 @@ class VllmConfig:
"--kv-sharing-fast-prefill requires changes on model side for "
"correctness and to realize prefill savings. "
)
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
self._set_compile_ranges()
if self.model_config and self.model_config.is_encoder_decoder:
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -809,7 +818,10 @@ class VllmConfig:
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
# Do this after all the updates to compilation_config.mode
self.compilation_config.set_splitting_ops_for_v1()
self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend,
data_parallel_size=self.parallel_config.data_parallel_size,
)
if self.compilation_config.pass_config.enable_sp:
# With pipeline parallelism or dynamo partitioning,
......@@ -1035,8 +1047,14 @@ class VllmConfig:
self.compilation_config.max_cudagraph_capture_size
)
if max_cudagraph_capture_size is None:
decode_query_len = 1
if (
self.speculative_config
and self.speculative_config.num_speculative_tokens
):
decode_query_len += self.speculative_config.num_speculative_tokens
max_cudagraph_capture_size = min(
self.scheduler_config.max_num_seqs * 2, 512
self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
......@@ -1133,6 +1151,52 @@ class VllmConfig:
# complete the remaining process.
self.compilation_config.post_init_cudagraph_sizes()
def _set_compile_ranges(self) -> None:
    """Compute and store the compile-range split points.

    The result is written back to
    ``self.compilation_config.compile_ranges_split_points`` as a list that
    is sorted in ascending order and contains no duplicates. It is built
    from:

    1. ``scheduler_config.max_num_batched_tokens`` (when set), which is the
       upper bound of the compile ranges.
    2. When ``pass_config.fuse_allreduce_rms`` is enabled, the number of
       tokens that fit in ``pass_config.flashinfer_max_size(tp_size)``,
       added only if it is below ``max_num_batched_tokens``.
    3. User-supplied split points ``x`` with
       ``1 < x < max_num_batched_tokens``; other values are dropped.

    Raises:
        AssertionError: If a user-supplied split point is not an ``int``
            or is not positive.
    """
    compilation_config = self.compilation_config
    pass_config = compilation_config.pass_config

    # Collected in a set so that a user-supplied point that is repeated, or
    # that coincides with a computed point, is stored only once.
    split_points: set[int] = set()

    # The upper bound of the compile ranges is the max_num_batched_tokens
    max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
    if max_num_batched_tokens is not None:
        split_points.add(max_num_batched_tokens)

    # Add the compile ranges for flashinfer. model_config is optional
    # elsewhere in this class; without it the byte budget cannot be
    # converted into a token count, so the flashinfer point is skipped
    # instead of failing on a None attribute access.
    if pass_config.fuse_allreduce_rms and self.model_config is not None:
        tp_size = self.parallel_config.tensor_parallel_size
        max_size = pass_config.flashinfer_max_size(tp_size)
        if max_size is not None:
            bytes_per_token = (
                self.model_config.get_hidden_size()
                * self.model_config.dtype.itemsize
            )
            max_token_num = max_size // bytes_per_token
            if (
                max_num_batched_tokens is not None
                and max_token_num < max_num_batched_tokens
            ):
                split_points.add(max_token_num)
            else:
                logger.debug(
                    "Max num batched tokens below allreduce-rms fusion threshold, "
                    "allreduce-rms fusion will be enabled for all num_tokens."
                )

    user_split_points = compilation_config.compile_ranges_split_points
    if user_split_points is not None:
        for x in user_split_points:
            assert isinstance(x, int)
            assert x > 0, f"Invalid compile range split point: {x}"
            # Only points strictly inside (1, max_num_batched_tokens) split
            # a range; the upper bound itself was already added above.
            if max_num_batched_tokens is not None and 1 < x < max_num_batched_tokens:
                split_points.add(x)

    compilation_config.compile_ranges_split_points = sorted(split_points)
def recalculate_max_model_len(self, max_model_len: int):
# Can only be called in try_verify_and_update_config
model_config = self.model_config
......
......@@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
if sizes is not None:
if sizes is not None and sizes.count(sizes[0]) != len(sizes):
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else:
pynccl_comm.reduce_scatter(output, input_tensor)
......
......@@ -27,6 +27,7 @@ from zmq import ( # type: ignore
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import (
get_ip,
get_open_port,
......@@ -632,7 +633,7 @@ class MessageQueue:
The MessageQueue instance for the calling process,
and a list of handles (only non-empty for the reader process).
"""
local_size = torch.cuda.device_count()
local_size = current_platform.device_count()
rank = dist.get_rank()
same_node = rank // local_size == reader_rank // local_size
buffer_io = MessageQueue(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment