Commit a99300bd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev

parents cc3e01c7 5438967f
......@@ -434,6 +434,14 @@ def validate_args(args):
if args.backend == "mii" and args.tokenizer != args.model:
raise ValueError(
"Tokenizer must be the same as the model for MII backend.")
# --data-parallel is not supported currently.
# https://github.com/vllm-project/vllm/issues/16222
if args.data_parallel_size > 1:
raise ValueError(
"Data parallel is not supported in offline benchmark, "
"please use benchmark serving instead"
)
def add_cli_args(parser: argparse.ArgumentParser):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
register_replacement)
from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
def silu_mul_pattern_static(result: torch.Tensor,
                            result_silu_mul: torch.Tensor, input: torch.Tensor,
                            scale: torch.Tensor):
    """
    Search pattern: `silu_and_mul` followed by a static-scale FP8 quant.

    The activation op writes into `result_silu_mul`; element 1 of its
    functionalized output is then fed to `static_scaled_fp8_quant`, which
    writes into `result` using `scale`. Element 1 of the quant op's
    functionalized output is returned.
    """
    silu_mul_op = torch.ops._C.silu_and_mul.default
    quant_op = torch.ops._C.static_scaled_fp8_quant.default

    activated = auto_functionalized(silu_mul_op,
                                    result=result_silu_mul,
                                    input=input)
    quantized = auto_functionalized(quant_op,
                                    result=result,
                                    input=activated[1],
                                    scale=scale)
    return quantized[1]
# FUSED_OPS: dict[QuantKey, OpOverload] = {
# kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
# }
# silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr(
# torch.ops._C, "silu_and_mul_nvfp4_quant"))
# if silu_and_mul_nvfp4_quant_supported:
# FUSED_OPS[
# kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
def silu_mul_replacement_static(result: torch.Tensor,
                                result_silu_mul: torch.Tensor,
                                input: torch.Tensor, scale: torch.Tensor):
    """
    Replacement graph: a single fused `silu_and_mul_quant` call.

    `result_silu_mul` is accepted so the signature mirrors the search
    pattern, but it is not used here. Element 1 of the fused op's
    functionalized output is returned.
    """
    fused_op = torch.ops._C.silu_and_mul_quant.default
    fused = auto_functionalized(fused_op,
                                result=result,
                                input=input,
                                scale=scale)
    return fused[1]
class ActivationQuantPattern(ABC):
"""
The base class for Activation+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
quant_key: QuantKey,
):
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
assert self.quant_key in QUANT_OPS, \
f"unsupported quantization scheme {self.quant_key}"
self.QUANT_OP = QUANT_OPS[self.quant_key]
assert self.quant_key in FUSED_OPS, \
f"unsupported fusion scheme {self.quant_key}"
self.FUSED_OP = FUSED_OPS[self.quant_key]
def empty_fp8(*args, **kwargs):
fp8 = current_platform.fp8_dtype()
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
def empty_quant(self, *args, **kwargs):
    """
    Allocate an uninitialized tensor for quantized data.

    Defaults to `self.quant_dtype` on the "cuda" device; any `dtype` or
    `device` supplied by the caller in `kwargs` takes precedence.
    """
    options = {'dtype': self.quant_dtype, 'device': "cuda"}
    options.update(kwargs)
    return torch.empty(*args, **options)
@abstractmethod
def register(self, pm_pass: PatternMatcherPass):
raise NotImplementedError
def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
    """
    Fusion for SiluMul+Fp8StaticQuant Pattern.

    Matches SILU_MUL_OP followed by the quant op selected for the FP8
    static-scale quant key, and replaces the pair with the fused op
    selected for the same key.
    """

    def __init__(self, symmetric: bool = True):
        # The quant key selects both QUANT_OP and FUSED_OP in the base class.
        super().__init__(
            QuantKey(dtype=FP8_DTYPE,
                     scale=kStaticTensorScale,
                     symmetric=symmetric))

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its fused replacement."""

        def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
                    input: torch.Tensor, scale: torch.Tensor):
            # Unfused form: activation first, then quantize its output.
            activated = auto_functionalized(SILU_MUL_OP,
                                            result=result_silu_mul,
                                            input=input)
            quantized = auto_functionalized(self.QUANT_OP,
                                            result=result,
                                            input=activated[1],
                                            scale=scale)
            return quantized[1]

        def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
                        input: torch.Tensor, scale: torch.Tensor):
            # Fused form: one op reads `input` and writes `result` directly.
            fused = auto_functionalized(self.FUSED_OP,
                                        result=result,
                                        input=input,
                                        scale=scale)
            return fused[1]

        # Example tensors used only for tracing the two functions above.
        example_inputs = [
            self.empty_quant(5, 4),  # result
            empty_bf16(5, 4),  # result_silu_mul
            empty_bf16(5, 4),  # input
            empty_fp32(1, 1)  # scale
        ]

        register_replacement(pattern, replacement, example_inputs, fwd_only,
                             pm_pass)
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
    """
    Fusion for SiluMul+Nvfp4Quant Pattern.

    Matches SILU_MUL_OP followed by the quant op selected for `kNvfp4Quant`
    and replaces the pair with the fused op selected for the same key. Both
    forms return two values: elements 1 and 2 of the functionalized output.
    """

    def __init__(self):
        super().__init__(kNvfp4Quant)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its fused replacement."""

        def pattern(result: torch.Tensor, output_scale: torch.Tensor,
                    result_silu_mul: torch.Tensor, input: torch.Tensor,
                    scale: torch.Tensor):
            # Unfused form: activation first, then quantize its output.
            activated = auto_functionalized(SILU_MUL_OP,
                                            result=result_silu_mul,
                                            input=input)
            quantized = auto_functionalized(self.QUANT_OP,
                                            output=result,
                                            input=activated[1],
                                            output_scale=output_scale,
                                            input_scale=scale)
            return quantized[1], quantized[2]

        def replacement(result: torch.Tensor, output_scale: torch.Tensor,
                        result_silu_mul: torch.Tensor, input: torch.Tensor,
                        scale: torch.Tensor):
            # Fused form: same buffers, passed under the fused op's own
            # keyword names.
            fused = auto_functionalized(self.FUSED_OP,
                                        result=result,
                                        result_block_scale=output_scale,
                                        input=input,
                                        input_global_scale=scale)
            return fused[1], fused[2]

        # Example tensors used only for tracing the two functions above.
        example_inputs = [
            self.empty_quant(5, 32),  # result
            empty_i32(128, 4),  # output_scale
            empty_bf16(5, 64),  # result_silu_mul
            empty_bf16(5, 64),  # input
            empty_fp32(1, 1)  # scale
        ]

        register_replacement(pattern, replacement, example_inputs, fwd_only,
                             pm_pass)
class ActivationQuantFusionPass(VllmInductorPass):
......@@ -61,21 +162,19 @@ class ActivationQuantFusionPass(VllmInductorPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="activation_quant_fusion_pass")
inputs = [
empty_fp8(5, 4), # Quant output
empty_bf16(5, 4), # Silu_and_mul output
empty_bf16(5, 4), # Input
empty_fp32(1, 1) # Scale
]
register_replacement(silu_mul_pattern_static,
silu_mul_replacement_static, inputs, fwd_only,
self.patterns)
pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
pattern_silu_mul_fp8.register(self.patterns)
if silu_and_mul_nvfp4_quant_supported:
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)
def __call__(self, graph: torch.fx.Graph):
self.begin()
......@@ -87,3 +186,8 @@ class ActivationQuantFusionPass(VllmInductorPass):
self.dump_graph(graph, "after_act_quant_fusion")
self.end_and_log()
def uuid(self):
    """
    Return the value of `VllmInductorPass.hash_source` computed over this
    pass together with the activation+quant pattern classes.
    """
    pattern_classes = (
        ActivationQuantPattern,
        SiluMulFp8StaticQuantPattern,
        SiluMulNvfp4QuantPattern,
    )
    return VllmInductorPass.hash_source(self, *pattern_classes)
......@@ -271,7 +271,7 @@ def split_graph(graph: fx.GraphModule,
outputs.append(
SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
# sort by intetger graph_id, rather than string name
# sort by integer graph_id, rather than string name
outputs.sort(key=lambda x: x.graph_id)
return split_gm, outputs
......@@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
def __init__(self, module: torch.fx.GraphModule,
compile_submod_names: list[str], vllm_config: VllmConfig,
graph_pool, vllm_backend: "VllmBackend"):
vllm_backend: "VllmBackend"):
super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.vllm_config = vllm_config
self.vllm_backend = vllm_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
......@@ -359,7 +358,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
graph_pool=self.graph_pool,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
......@@ -405,7 +403,6 @@ class VllmBackend:
vllm_config: VllmConfig
compilation_config: CompilationConfig
graph_pool: Any
_called: bool = False
# the graph we compiled
graph: fx.GraphModule
......@@ -427,19 +424,12 @@ class VllmBackend:
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. launguage_model, vision_model, etc.
# e.g. language_model, vision_model, etc.
# when multiple parts are initialized as independent
# models, we need to use the model_tag to distinguish
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag
global_graph_pool = current_platform.get_global_graph_pool()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = global_graph_pool
# Passes to run on the graph post-grad.
self.post_grad_pass_manager = PostGradPassManager()
......@@ -484,7 +474,7 @@ class VllmBackend:
factors = []
# 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affects the computation graph.
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
env_hash = envs.compute_hash()
factors.append(env_hash)
......@@ -586,7 +576,7 @@ class VllmBackend:
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
self.vllm_config, self.graph_pool,
self.vllm_config,
self).run(*example_inputs)
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
......
......@@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol):
"""
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
runtime_mode: CUDAGraphMode, **kwargs):
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.
......@@ -25,9 +25,6 @@ class AbstractStaticGraphWrapper(Protocol):
graph runtime. See CUDAGraphMode in vllm/config.py.
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
are used as concrete runtime mode for cudagraph dispatching.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.
......
......@@ -10,6 +10,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
......@@ -18,6 +19,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
FP8_DTYPE = current_platform.fp8_dtype()
......@@ -348,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
class AsyncTPPass(VllmInductorPass):
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
......@@ -401,6 +404,18 @@ if flashinfer_comm is not None:
6: MiB // 2, # 512KB
8: MiB // 2, # 512KB
}
try:
_FI_MAX_SIZES.update({
int(k): int(float(v) * MiB)
for k, v in
envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
})
except Exception as e:
raise ValueError(
"Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: "
+ str(e)) from e
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE = MiB // 2
......@@ -465,7 +480,8 @@ if flashinfer_comm is not None:
quant_out=quant_out,
scale_out=scale_out,
# in vllm we only support swizzled layout
layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED,
layout_code=flashinfer_comm.QuantizationSFLayout.
SWIZZLED_128x4,
scale_factor=scale_factor,
)
else:
......@@ -1107,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass):
# in fallback path, when we don't use flashinfer
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
self.register_patterns()
@enable_fake_mode
def register_patterns(self):
for epsilon in [1e-5, 1e-6]:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
......
......@@ -67,11 +67,9 @@ class CUDAGraphWrapper:
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
graph_pool: Any = None,
cudagraph_options: Optional[CUDAGraphOptions] = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.graph_pool = graph_pool
self.runtime_mode = runtime_mode
self.compilation_config = vllm_config.compilation_config
......@@ -81,8 +79,10 @@ class CUDAGraphWrapper:
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
# need to initialize a CUDAGraphWrapper.
assert self.runtime_mode != CUDAGraphMode.NONE
if self.graph_pool is None:
self.graph_pool = current_platform.get_global_graph_pool()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = current_platform.get_global_graph_pool()
if cudagraph_options is None:
cudagraph_options = CUDAGraphOptions()
......
......@@ -54,6 +54,14 @@ def _should_ignore_torch_compile(cls) -> bool:
return getattr(cls, IGNORE_COMPILE_KEY, False)
# Typing-only overload: `support_torch_compile` called with just the
# keyword-only `enable_if` predicate is typed as returning a class decorator.
@overload
def support_torch_compile(
*,
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Callable[[_T], _T]:
...
@overload
def support_torch_compile(
*,
......@@ -71,6 +79,7 @@ def support_torch_compile(
cls: Optional[_T] = None,
*,
dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Union[Callable[[_T], _T], _T]:
"""
A decorator to add support for compiling the forward method of a class.
......@@ -120,6 +129,11 @@ def support_torch_compile(
NOTE: if an argument is `None`, it should always be passed as `None` during
the lifetime of the model, otherwise, it cannot be captured as a single
computation graph.
`enable_if` is a function that takes a `VllmConfig` object as input and
returns a boolean value indicating whether to compile the model or not.
This is useful if you want to compile the model only when certain
conditions are met.
"""
def cls_decorator_helper(cls: _T) -> _T:
......@@ -151,7 +165,8 @@ def support_torch_compile(
if k not in sig.parameters:
raise ValueError(
f"Argument {k} not found in the forward method of {cls}")
return _support_torch_compile(cls, inferred_dynamic_arg_dims)
return _support_torch_compile(cls, inferred_dynamic_arg_dims,
enable_if)
if cls is not None:
# use `support_torch_compile` as a decorator without arguments
......@@ -164,6 +179,7 @@ def support_torch_compile(
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: dict[str, Union[int, list[int]]],
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
......@@ -184,13 +200,14 @@ def _support_torch_compile(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = \
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo() or _should_ignore_torch_compile(
self.__class__)
self.__class__) or not enable_compile
if self.do_not_compile:
return
......@@ -273,8 +290,24 @@ def _support_torch_compile(
code.co_filename)
return inline_call(parent, func, args, kwargs)
# Disable the C++ compilation of symbolic shape guards. C++-fication
# of symbolic shape guards can reduce guard overhead. But, since
# vLLM skips guards anyway, setting this flag to False can improve
# compile time.
dynamo_config_patches = {}
try:
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
dynamo_config_patches[
"enable_cpp_symbolic_shape_guards"] = False
except AttributeError:
# Note: this config is not available in torch 2.6, we can skip
# if the config doesn't exist
logger.debug(
"enable_cpp_symbolic_shape_guards config not available")
with patch.object(InliningInstructionTranslator, 'inline_call',
patched_inline_call):
patched_inline_call), torch._dynamo.config.patch(
**dynamo_config_patches):
output = self.compiled_callable(*args, **kwargs)
return output
......
......@@ -9,6 +9,7 @@ import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass
......@@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
def __call__(self, graph: torch.fx.Graph):
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if current_platform.is_xpu():
logger.debug("XPU platform does not support fix functionalization"
"pass currently.")
return
self.begin()
self.dump_graph(graph, "before_fix_functionalization")
......@@ -89,6 +97,15 @@ class FixFunctionalizationPass(VllmInductorPass):
# node,
# mutated_args,
# args=('result', 'input', 'scale'))
# elif hasattr(
# torch.ops._C, "silu_and_mul_nvfp4_quant"
# ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
# mutated_args = {1: 'result', 2: 'result_block_scale'}
# self.defunctionalize(graph,
# node,
# mutated_args,
# args=('result', 'result_block_scale',
# 'input', 'input_global_scale'))
else:
continue # skip the count
......
......@@ -12,15 +12,18 @@ from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
from .inductor_pass import enable_fake_mode
from .multi_output_match import MultiOutputMatch
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def empty_bf16(*args, **kwargs):
......@@ -31,41 +34,12 @@ def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
group_shape: quantization group shape
symmetric: symmetric if True, asymmetric if False
TODO(luka) use QuantDescriptor once standardized:
https://github.com/vllm-project/vllm/issues/8913
def empty_i32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
"""
dtype: torch.dtype
static: bool
group_shape: GroupShape
symmetric: bool = True
def __str__(self):
group_shape = ('per_tensor'
if self.group_shape == GroupShape.PER_TENSOR else
('per_token' if self.group_shape == GroupShape.PER_TOKEN
else str(self.group_shape)))
return (f"QuantKey({'static' if self.static else 'dynamic'},"
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
f"{'a' if not self.symmetric else ''}symmetric)")
# kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
# kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
# kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8StaticTensorSym:
......@@ -75,6 +49,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8DynamicTokenSym:
# torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[
kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
class FusedRMSQuantKey(NamedTuple):
......@@ -187,11 +164,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
fused_key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
quant=QuantKey(dtype=quant_dtype,
scale=kStaticTensorScale,
symmetric=symmetric))
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass):
......@@ -244,11 +219,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
quant=QuantKey(dtype=quant_dtype,
scale=kStaticTensorScale,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
......@@ -337,10 +310,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=False,
group_shape=group_shape,
scale=scale,
symmetric=symmetric))
super().__init__(epsilon, key)
......@@ -435,10 +408,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=False,
group_shape=group_shape,
scale=scale,
symmetric=symmetric))
super().__init__(epsilon, key)
......@@ -556,6 +529,7 @@ class FusionPass(VllmInductorPass):
cls._instance.pass_config = config.compilation_config.pass_config
return cls._instance
@enable_fake_mode
def __init__(self, config: VllmConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._subclasses.fake_tensor import (FakeTensorMode,
unset_fake_temporarily)
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from vllm.utils import round_up
from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionStaticQuantPattern:
class AttentionQuantPattern(ABC):
"""
The base class for Attn+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
layer_name: str,
num_heads: int,
head_size: int,
quant_dtype: torch.dtype,
symmetric=True,
layer: Attention,
quant_key: QuantKey,
):
self.layer_name = layer_name
self.num_heads = num_heads
self.head_size = head_size
self.quant_dtype = quant_dtype
self.quant_key = QuantKey(dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric)
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.head_size = layer.head_size
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
assert self.quant_key in QUANT_OPS, \
f"unsupported quantization scheme {self.quant_key}"
self.QUANT_OP = QUANT_OPS[self.quant_key]
......@@ -48,31 +55,64 @@ class AttentionStaticQuantPattern:
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def register_if_supported(self, pm_pass: PatternMatcherPass,
layer: Attention):
if layer.impl.fused_output_quant_supported(self.quant_dtype,
self.quant_key.static,
self.quant_key.group_shape):
@staticmethod
def wrap_trace_fn(process_fx, trace_fn):
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(self.quant_key):
self._register(pm_pass)
@abstractmethod
def _register(self, pm_pass: PatternMatcherPass):
raise NotImplementedError
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Fp8StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Fp8StaticQuant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(
self,
layer: Attention,
symmetric: bool = True,
):
quant_key = QuantKey(dtype=FP8_DTYPE,
scale=kStaticTensorScale,
symmetric=symmetric)
super().__init__(layer, quant_key)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
scale: torch.Tensor):
view_7 = RESHAPE_OP(output_attn,
[-1, self.num_heads, self.head_size])
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=view_7,
output=output_attn,
layer_name=self.layer_name,
output_scale=None)
attn_out_view = RESHAPE_OP(at1[1],
[-1, self.num_heads * self.head_size])
output_scale=None,
output_block_scale=None)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
result=output_quant,
input=attn_out_view,
......@@ -82,47 +122,116 @@ class AttentionStaticQuantPattern:
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
scale: torch.Tensor):
view_7 = RESHAPE_OP(output_quant,
[-1, self.num_heads, self.head_size])
# attn output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size],
0.0,
dtype=self.quant_dtype,
device=q.device)
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=view_7,
output=output_attn,
layer_name=self.layer_name,
output_scale=scale)
output_scale=scale,
output_block_scale=None)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
# Need custom fake mode, otherwise tracing happens with real tensors.
# That would not work for the unified_attention custom op.
with unset_fake_temporarily(), FakeTensorMode():
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads * self.head_size), # attn_output
self.empty_quant(5, self.num_heads *
self.head_size), # quant_output
empty_fp32(1, 1) # scale
]
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # attn_output
self.empty_quant(5,
self.num_heads * self.head_size), # quant_output
empty_fp32(1, 1) # scale
]
def wrap_trace_fn(process_fx, trace_fn):
pm.register_replacement(
pattern, replacement, inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass)
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Nvfp4Quant.
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Nvfp4Quant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(self, layer: Attention):
super().__init__(layer, kNvfp4Quant)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, input_scale: torch.Tensor):
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
output=output_quant,
input=attn_out_view,
output_scale=output_scale,
input_scale=input_scale)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
pm.register_replacement(
pattern, replacement, inputs,
wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, input_scale: torch.Tensor):
# attention output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2],
0.0,
dtype=self.quant_dtype,
device=q.device)
# attention output block scale
output_scale_view = torch.ops.aten.view.dtype(
output_scale, FP8_DTYPE)
at2 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view)
output = RESHAPE_OP(at2[1],
[-1, self.num_heads * self.head_size // 2])
return output, at2[2]
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # output_attn
self.empty_quant(5, self.num_heads * self.head_size //
2), # output_quant
empty_i32(128, round_up(self.num_heads * self.head_size // 16,
4)), # output_scale
empty_fp32(1, 1), # input_scale
]
pm.register_replacement(
pattern, replacement, inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass)
class AttnFusionPass(VllmInductorPass):
......@@ -138,32 +247,42 @@ class AttnFusionPass(VllmInductorPass):
support are attention kernels, which need to support fusing output quant.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.static_fwd_ctx = config.compilation_config.static_forward_context
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
for key, layer in self.static_fwd_ctx.items():
pattern = AttentionStaticQuantPattern(key, layer.num_heads,
layer.head_size,
current_platform.fp8_dtype())
pattern.register_if_supported(self.patterns, layer)
if len(self.static_fwd_ctx) == 0:
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
pattern_fp8.register_if_supported(self.patterns)
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0:
logger.warning(
"Attention + quant fusion is enabled, but "
"CompilationConfig.static_forward_context is empty. "
"Cannot access attention layers so no fusion "
"patterns were registered.")
"Attention + quant fusion is enabled, but no attention layers "
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered.")
def __call__(self, graph: torch.fx.graph.Graph) -> None:
    """Apply the registered attention + quant patterns to ``graph`` in place.

    The graph is dumped before and after the rewrite, dead code is removed
    after the patterns have been applied, and the value returned by
    ``self.patterns.apply`` is logged at debug level.

    Args:
        graph: The FX graph to rewrite.
    """
    self.begin()
    self.dump_graph(graph, "before_attn_fusion")

    # The return value is only used for the debug log line below.
    count = self.patterns.apply(graph)

    # TODO: Move this to pass_manager.py after the fx graph broken issue
    # has been resolved.
    # see https://github.com/vllm-project/vllm/issues/23091
    graph.eliminate_dead_code()

    logger.debug("Fused quantization onto %s attention nodes", count)
    self.dump_graph(graph, "after_attn_fusion")
    self.end_and_log()
def uuid(self):
return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern)
return VllmInductorPass.hash_source(self, AttentionQuantPattern,
AttentionFp8StaticQuantPattern,
AttentionNvfp4QuantPattern)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import hashlib
import inspect
import json
......@@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
import torch
from torch import fx
from torch._subclasses.fake_tensor import (FakeTensorMode,
unset_fake_temporarily)
from vllm.utils import is_torch_equal_or_newer
......@@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
def uuid(self) -> Any:
return self._uuid
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@functools.wraps(fn)
def fn_new(*args, **kwargs) -> Any:
with torch._guards.tracing(
None), unset_fake_temporarily(), FakeTensorMode():
result = fn(*args, **kwargs)
return result
return fn_new
......@@ -43,7 +43,7 @@ cudagraph_capturing_enabled: bool = True
def validate_cudagraph_capturing_enabled():
# used to monitor whether an cudagraph capturing is legal at runtime.
# used to monitor whether a cudagraph capturing is legal at runtime.
# should be called before any cudagraph capturing.
# if an illegal cudagraph capturing happens, raise an error.
global cudagraph_capturing_enabled
......
......@@ -8,13 +8,13 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import FusionPass
from .fusion_attn import AttnFusionPass
if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
# from .activation_quant_fusion import ActivationQuantFusionPass
from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
......
......@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
......@@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass):
performance.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
......
......@@ -36,7 +36,8 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
PrefixCachingHashAlgo)
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
CUDAGraphMode, PassConfig)
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
ParallelConfig)
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger
......@@ -199,7 +200,17 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
yield a, b
a = b
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
try:
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
except (OSError, KeyError, TypeError):
# HACK: Python 3.13+ workaround - set missing __firstlineno__
# Workaround can be removed after we upgrade to pydantic==2.12.0
with open(inspect.getfile(cls)) as f:
for i, line in enumerate(f):
if f"class {cls.__name__}" in line and ":" in line:
cls.__firstlineno__ = i + 1
break
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
if not isinstance(cls_node, ast.ClassDef):
raise TypeError("Given object was not a class.")
......@@ -254,8 +265,14 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode = Literal["auto", "cpm", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
"processed_logits"]
MMEncoderTPMode = Literal["weights", "data"]
class LogprobsMode(enum.Enum):
    """Closed set of modes for the ``logprobs_mode`` setting.

    Each member's value is the lowercase string form of the mode.
    """
    # "Raw" members.
    RAW_LOGITS = "raw_logits"
    RAW_LOGPROBS = "raw_logprobs"
    # "Processed" members.
    PROCESSED_LOGITS = "processed_logits"
    PROCESSED_LOGPROBS = "processed_logprobs"
@config
......@@ -359,12 +376,13 @@ class ModelConfig:
specified in `SamplingParams`. The default value comes the default for the
OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
vocab_size) logprobs are allowed to be returned and it may cause OOM."""
logprobs_mode: LogprobsMode = "raw_logprobs"
logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported mode:
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
Raw means the values before applying logit processors, like bad words.
Processed means the values after applying such processors.
Raw means the values before applying any logit processors, like bad words.
Processed means the values after applying all processors, including
temperature and top_k/top_p.
"""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
......@@ -427,7 +445,7 @@ class ModelConfig:
from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
"""
mm_processor_cache_gb: int = 4
mm_processor_cache_gb: float = 4
"""The size (in GiB) of the multi-modal processor cache, which is used to
avoid re-processing past multi-modal inputs.
......@@ -436,6 +454,19 @@ class ModelConfig:
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
Set to `0` to disable this cache completely (not recommended)."""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
......@@ -478,6 +509,8 @@ class ModelConfig:
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
"""One or more logits processors' fully-qualified class names or class
definitions"""
io_processor_plugin: Optional[str] = None
"""IOProcessor plugin name to load at model startup"""
enable_chunked_prefill: Optional[bool] = None
"""If True, prefill requests can be chunked based
......@@ -854,22 +887,25 @@ class ModelConfig:
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
if self._model_info.supports_multimodal:
if (self.mm_encoder_tp_mode == "data" and
not self._model_info.supports_multimodal_encoder_tp_data):
logger.warning_once(
"This model does not support `--mm-encoder-tp-mode data`. "
"Falling back to `--mm-encoder-tp-mode weights`.")
self.mm_encoder_tp_mode = "weights"
return MultiModalConfig(
limit_per_prompt=self.limit_mm_per_prompt,
media_io_kwargs=self.media_io_kwargs,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
interleave_mm_strings=self.interleave_mm_strings,
skip_mm_profiling=self.skip_mm_profiling)
skip_mm_profiling=self.skip_mm_profiling,
)
return None
def set_mm_processor_cache_gb(self, value: int) -> None:
mm_config = self.get_multimodal_config()
self.mm_processor_cache_gb = value
mm_config.mm_processor_cache_gb = value
def _get_encoder_config(self):
return get_sentence_transformer_tokenizer_config(
self.model, self.revision)
......@@ -1099,10 +1135,22 @@ class ModelConfig:
def _verify_quantization(self) -> None:
supported_quantization = me_quant.QUANTIZATION_METHODS
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc",
"slimquant_w4a8", "slimquant_w4a8_marlin"
"fp8",
"modelopt",
"gptq_marlin_24",
"gptq_marlin",
"awq_marlin",
"fbgemm_fp8",
"compressed-tensors",
"experts_int8",
"quark",
"modelopt_fp4",
"bitblas",
"gptq_bitblas",
"inc",
"petit_nvfp4",
"slimquant_w4a8",
"slimquant_w4a8_marlin"
]
if self.quantization is not None:
self.quantization = cast(me_quant.QuantizationMethods,
......@@ -1125,7 +1173,6 @@ class ModelConfig:
# `override_quantization_method` method) must be checked in order
# of preference (this is particularly important for GPTQ).
overrides = [
"marlin",
"bitblas",
"gptq_marlin_24",
"gptq_marlin",
......@@ -1136,6 +1183,7 @@ class ModelConfig:
"slimquant_w4a8_marlin"
"modelopt",
"modelopt_fp4",
"petit_nvfp4",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides
......@@ -1470,7 +1518,8 @@ class ModelConfig:
from vllm.distributed.utils import get_pp_indices
if (self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp"
or self.hf_config.model_type == "glm4_moe_mtp"):
or self.hf_config.model_type == "glm4_moe_mtp"
or self.hf_config.model_type == "ernie_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
......@@ -1670,29 +1719,8 @@ class ModelConfig:
return self.multimodal_config is not None
@property
def processor_return_mm_hashes(self) -> bool:
"""Whether the multi-modal processor should output hashes."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return mm_config.mm_processor_cache_gb > 0
@property
def enable_mm_processor_cache(self) -> bool:
"""Whether the multi-modal processor cache should be enabled."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return mm_config.mm_processor_cache_gb > 0
def get_mm_input_cache_gb(self) -> int:
mm_config = self.multimodal_config
if mm_config is None:
return 0
return envs.VLLM_MM_INPUT_CACHE_GIB
def is_multimodal_raw_input_only_model(self) -> bool:
return self._model_info.supports_multimodal_raw_input_only
@property
def is_cross_encoder(self) -> bool:
......@@ -1703,10 +1731,6 @@ class ModelConfig:
def is_pp_supported(self) -> bool:
return self._model_info.supports_pp
@property
def is_multimodal_raw_input_supported(self) -> bool:
return self._model_info.supports_multimodal_raw_input
@property
def is_attention_free(self) -> bool:
return self._model_info.is_attention_free
......@@ -1917,7 +1941,8 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp"]
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp"]
@config
......@@ -2050,6 +2075,16 @@ class SpeculativeConfig:
"architectures": ["Glm4MoeMTPModel"]
})
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
return hf_config
return hf_config
def __post_init__(self):
......@@ -2068,8 +2103,8 @@ class SpeculativeConfig:
if self.target_model_config and \
(self.target_model_config.hf_text_config.model_type \
== "deepseek_v3" or
self.target_model_config.hf_text_config.model_type \
== "mimo"):
self.target_model_config.hf_text_config.model_type in
("mimo","ernie4_5_moe")):
# use the draft model from the same model:
self.model = self.target_model_config.model
elif self.method in ("ngram", "[ngram]"):
......@@ -2167,6 +2202,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
......@@ -2386,7 +2430,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp")
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
def __repr__(self) -> str:
method = self.method
......@@ -2422,8 +2466,8 @@ class LoRAConfig:
lora_dtype: Union[torch.dtype, LoRADType] = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
lora_extra_vocab_size: int = 256
"""Maximum size of extra vocabulary that can be present in a LoRA adapter
(added to the base model vocabulary)."""
"""(Deprecated) Maximum size of extra vocabulary that can be present in a
LoRA adapter. Will be removed in v0.12.0."""
lora_vocab_padding_size: ClassVar[int] = current_platform\
.get_lora_vocab_padding_size()
......@@ -2465,6 +2509,12 @@ class LoRAConfig:
return hash_str
def __post_init__(self):
# Deprecation warning for lora_extra_vocab_size
logger.warning(
"`lora_extra_vocab_size` is deprecated and will be removed "
"in v0.12.0. Additional vocabulary support for "
"LoRA adapters is being phased out.")
# Setting the maximum rank to 512 should be able to satisfy the vast
# majority of applications.
possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
......@@ -2529,7 +2579,7 @@ class MultiModalConfig:
`{"num_crops": 4}`.
"""
mm_processor_cache_gb: int = 4
mm_processor_cache_gb: float = 4
"""
The size (in GiB) of the multi-modal processor cache, which is used to
......@@ -2540,6 +2590,22 @@ class MultiModalConfig:
Set to `0` to disable this cache completely (not recommended).
"""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""
Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP.
"""
interleave_mm_strings: bool = False
"""
Enable fully interleaved support for multimodal prompts.
......@@ -2547,7 +2613,7 @@ class MultiModalConfig:
skip_mm_profiling: bool = False
"""
When enabled, skips multimodal memory profiling and only profiles with
When enabled, skips multimodal memory profiling and only profiles with
language backbone model during engine initialization.
This reduces engine startup time but shifts the responsibility to users for
......@@ -2610,24 +2676,24 @@ class PoolerConfig:
## for embeddings models
normalize: Optional[bool] = None
"""
Whether to normalize the embeddings outputs.
Whether to normalize the embeddings outputs.
"""
dimensions: Optional[int] = None
"""
Reduce the dimensions of embeddings if model
Reduce the dimensions of embeddings if model
support matryoshka representation.
"""
## for classification models
activation: Optional[bool] = None
"""
Whether to apply activation function to the classification outputs.
Whether to apply activation function to the classification outputs.
"""
## for reward models
softmax: Optional[bool] = None
"""
Whether to apply softmax to the reward outputs.
Whether to apply softmax to the reward outputs.
"""
step_tag_id: Optional[int] = None
"""
......@@ -2653,9 +2719,9 @@ class PoolerConfig:
max_embed_len: Optional[int] = None
"""
Maximum input length allowed for embedding generation. When set, allows
Maximum input length allowed for embedding generation. When set, allows
inputs longer than max_embed_len to be accepted for embedding models.
This parameter enables accepting long inputs without requiring
This parameter enables accepting long inputs without requiring
VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
max_embed_len, it will be handled according to the original max_model_len
validation logic. Defaults to None (i.e. set to max_model_len).
......@@ -3009,7 +3075,8 @@ def get_served_model_name(model: str,
return served_model_name
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines",
"lm-format-enforcer"]
@config
......@@ -3572,7 +3639,7 @@ class VllmConfig:
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.is_cuda_alike():
if current_platform.is_cuda_alike() or current_platform.is_xpu():
# if cudagraph_mode is not explicitly set by users, set default
# value
if self.compilation_config.cudagraph_mode is None:
......
......@@ -115,8 +115,8 @@ class CacheConfig:
In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
some layers can skip tokens corresponding to prefill. This flag enables
attention metadata for eligible layers to be overriden with metadata
necessary for implementating this optimization in some models (e.g. Gemma3n)
attention metadata for eligible layers to be overridden with metadata
necessary for implementing this optimization in some models (e.g. Gemma3n)
"""
def compute_hash(self) -> str:
......@@ -145,12 +145,19 @@ class CacheConfig:
self._verify_cache_dtype()
self._verify_prefix_caching()
self._verify_kv_sharing_fast_prefill()
def metrics_info(self):
    """Return the instance attributes as a ``dict[str, str]``.

    Intended for Prometheus info metrics: every entry of the instance
    ``__dict__`` is kept under its attribute name, with the value converted
    via ``str``.
    """
    info = {}
    for attr_name, attr_value in self.__dict__.items():
        info[attr_name] = str(attr_value)
    return info
def _verify_kv_sharing_fast_prefill(self) -> None:
if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1:
raise NotImplementedError(
"Fast prefill optimization for KV sharing is not supported "
"in V0 currently.")
@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.cpu_offload_gb < 0:
......@@ -162,11 +169,6 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
if self.kv_sharing_fast_prefill:
logger.warning_once(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)")
return self
def _verify_cache_dtype(self) -> None:
......
......@@ -225,7 +225,8 @@ class CompilationConfig:
# CudaGraph compilation
cudagraph_mode: Optional[CUDAGraphMode] = None
"""
The mode of the cudagraph.
The mode of the cudagraph:
- NONE, no cudagraph capture.
- PIECEWISE. (v1 default)
- FULL.
......@@ -336,6 +337,9 @@ class CompilationConfig:
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
]
def compute_hash(self) -> str:
......@@ -382,13 +386,10 @@ class CompilationConfig:
if pass_config_exclude:
exclude["pass_config"] = pass_config_exclude
# The cast to string is necessary because Pydantic is mocked in docs
# builds and sphinx-argparse doesn't know the return type of decode()
return str(
TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode())
return TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode()
__str__ = __repr__
......
......@@ -15,7 +15,7 @@ import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, get_open_port
from vllm.utils import cuda_device_count_stateless, get_open_ports_list
if TYPE_CHECKING:
from ray.runtime_env import RuntimeEnv
......@@ -32,6 +32,31 @@ logger = init_logger(__name__)
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
@config
@dataclass
class EPLBConfig:
    """Configuration for Expert Parallel Load Balancing (EPLB)."""

    window_size: int = 1000
    """Window size for expert load recording."""
    step_interval: int = 3000
    """
    Interval for rearranging experts in expert parallelism.

    Note that if this is greater than the EPLB window size, only the metrics
    of the last `window_size` steps will be used for rearranging experts.
    """
    num_redundant_experts: int = 0
    """Number of redundant experts to use for expert parallelism."""
    log_balancedness: bool = False
    """
    Log the balancedness each step of expert parallelism.
    This is turned off by default since it will cause communication overhead.
    """
@config
@dataclass
class ParallelConfig:
......@@ -75,22 +100,24 @@ class ParallelConfig:
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
"""Enable expert parallelism load balancing for MoE layers."""
num_redundant_experts: int = 0
"""Number of redundant experts to use for expert parallelism."""
eplb_window_size: int = 1000
"""Window size for expert load recording."""
eplb_step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `eplb_window_size` steps will be used for rearranging experts.
"""
eplb_log_balancedness: bool = False
"""
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
eplb_config: EPLBConfig = field(default_factory=EPLBConfig)
"""Expert parallelism configuration."""
num_redundant_experts: Optional[int] = None
"""`num_redundant_experts` is deprecated and has been replaced with
`eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
Please use `eplb_config.num_redundant_experts` instead."""
eplb_window_size: Optional[int] = None
"""`eplb_window_size` is deprecated and has been replaced with
`eplb_config.window_size`. This will be removed in v0.12.0.
Please use `eplb_config.window_size` instead."""
eplb_step_interval: Optional[int] = None
"""`eplb_step_interval` is deprecated and has been replaced with
`eplb_config.step_interval`. This will be removed in v0.12.0.
Please use `eplb_config.step_interval` instead."""
eplb_log_balancedness: Optional[bool] = None
"""`eplb_log_balancedness` is deprecated and has been replaced with
`eplb_config.log_balancedness`. This will be removed in v0.12.0.
Please use `eplb_config.log_balancedness` instead."""
max_parallel_loading_workers: Optional[int] = None
"""Maximum number of parallel loading workers when loading model
......@@ -109,7 +136,8 @@ class ParallelConfig:
placement_group: Optional[PlacementGroup] = None
"""ray distributed model workers placement group."""
distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
distributed_executor_backend: Optional[Union[str,
DistributedExecutorBackend,
type[ExecutorBase]]] = None
"""Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
......@@ -137,9 +165,10 @@ class ParallelConfig:
rank: int = 0
"""Global rank in distributed setup."""
enable_multimodal_encoder_data_parallel: bool = False
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
_data_parallel_master_port_list: list[int] = field(default_factory=list)
"""List of open port auto-queried for data parallel messaging.
Set to be private as it's not intended to be configured by users.
"""
@property
def world_size_across_dp(self) -> int:
......@@ -153,11 +182,15 @@ class ParallelConfig:
processes that is related to data parallelism,
e.g. both in the worker and in the engine, which
can live in different processes. To avoid port conflicts, we
increment the port number each time we need to initialize a
new process group related to data parallelism.
pop a new port from the prepared port list each time we need to
initialize a new process group related to data parallelism.
"""
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
if self._data_parallel_master_port_list:
answer = self._data_parallel_master_port_list.pop()
else:
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
return answer
def stateless_init_dp_group(self) -> ProcessGroup:
......@@ -241,6 +274,38 @@ class ParallelConfig:
return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None:
# Forward deprecated fields to their new location
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = (
self.num_redundant_experts)
logger.warning_once(
"num_redundant_experts is deprecated and has been replaced "
"with eplb_config.num_redundant_experts. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_window_size is not None:
self.eplb_config.window_size = self.eplb_window_size
logger.warning_once(
"eplb_window_size is deprecated and has been replaced "
"with eplb_config.window_size. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_step_interval is not None:
self.eplb_config.step_interval = self.eplb_step_interval
logger.warning_once(
"eplb_step_interval is deprecated and has been replaced "
"with eplb_config.step_interval. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_log_balancedness is not None:
self.eplb_config.log_balancedness = self.eplb_log_balancedness
logger.warning_once(
"eplb_log_balancedness is deprecated and has been replaced "
"with eplb_config.log_balancedness. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
# Continue with the rest of the initialization
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
......@@ -251,7 +316,10 @@ class ParallelConfig:
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args.
self.data_parallel_master_port = get_open_port()
if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = \
self._data_parallel_master_port_list.pop()
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
raise ValueError(
......@@ -279,10 +347,10 @@ class ParallelConfig:
raise ValueError(
"Expert parallelism load balancing is only supported on "
"CUDA devices now.")
if self.num_redundant_experts < 0:
if self.eplb_config.num_redundant_experts < 0:
raise ValueError(
"num_redundant_experts must be non-negative, but got "
f"{self.num_redundant_experts}.")
f"{self.eplb_config.num_redundant_experts}.")
if not self.enable_expert_parallel:
raise ValueError(
"enable_expert_parallel must be True to use EPLB.")
......@@ -293,10 +361,10 @@ class ParallelConfig:
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
)
else:
if self.num_redundant_experts != 0:
if self.eplb_config.num_redundant_experts != 0:
raise ValueError(
"num_redundant_experts should be used with EPLB."
f"{self.num_redundant_experts}.")
f"{self.eplb_config.num_redundant_experts}.")
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
......@@ -342,23 +410,22 @@ class ParallelConfig:
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
isinstance(self.distributed_executor_backend, type)
and self.distributed_executor_backend.uses_ray)
and getattr(self.distributed_executor_backend, "uses_ray", False))
@model_validator(mode='after')
def _verify_args(self) -> Self:
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform
if self.distributed_executor_backend not in (
"ray", "mp", "uni",
"external_launcher", None) and not (isinstance(
if self.distributed_executor_backend is not None and not isinstance(
self.distributed_executor_backend, str) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
self.distributed_executor_backend, ExecutorBase)):
raise ValueError(
"Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' 'uni', 'external_launcher' or"
" custom ExecutorBase subclass.")
"values are 'ray', 'mp' 'uni', 'external_launcher', "
" custom ExecutorBase subclass or its import path.")
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
......
......@@ -207,7 +207,7 @@ class NaiveBlockAllocator(BlockAllocator):
Args:
absolute_id (int): The absolute block id for the block
in whole allocator.
in whole allocator.
Returns:
int: The zero-offset block id on certain device.
......
......@@ -61,7 +61,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Args:
num_blocks (int): The total number of blocks to manage.
block_size (int): The size of each block in tokens.
block_ids(Optional[Iterable[int]], optional): An optional iterable of
block_ids (Optional[Iterable[int]], optional): An optional iterable of
block IDs. If not provided, block IDs will be assigned sequentially
from 0 to num_blocks - 1.
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment