Commit a99300bd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev

parents cc3e01c7 5438967f
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib import hashlib
import json
import os import os
import sys import sys
import tempfile import tempfile
...@@ -42,7 +43,6 @@ if TYPE_CHECKING: ...@@ -42,7 +43,6 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0 VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0
VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_OMP_THREADS_BIND: str = ""
...@@ -99,6 +99,7 @@ if TYPE_CHECKING: ...@@ -99,6 +99,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True
...@@ -131,7 +132,9 @@ if TYPE_CHECKING: ...@@ -131,7 +132,9 @@ if TYPE_CHECKING:
VLLM_TPU_USING_PATHWAYS: bool = False VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: str = "throughput" VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
...@@ -159,9 +162,12 @@ if TYPE_CHECKING: ...@@ -159,9 +162,12 @@ if TYPE_CHECKING:
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
# add envs # add envs
VLLM_OPTEST_URLS_PORT: Optional[int] = None VLLM_OPTEST_URLS_PORT: Optional[int] = None
...@@ -188,6 +194,7 @@ if TYPE_CHECKING: ...@@ -188,6 +194,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -491,11 +498,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -491,11 +498,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,
# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
"VLLM_FLASHINFER_FORCE_TENSOR_CORES":
lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
# Pipeline stage partition strategy # Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION": "VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
...@@ -693,11 +695,14 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -693,11 +695,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR": "VLLM_LORA_RESOLVER_CACHE_DIR":
lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None), lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None),
# Enables torch profiler if set. Path to the directory where torch profiler # Enables torch profiler if set.
# traces are saved. Note that it must be an absolute path. # Both AsyncLLM's CPU traces as well as workers'
# traces (CPU & GPU) will be saved under this directory.
# Note that it must be an absolute path.
"VLLM_TORCH_PROFILER_DIR": "VLLM_TORCH_PROFILER_DIR":
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), .path.abspath(os.path.expanduser(os.getenv(
"VLLM_TORCH_PROFILER_DIR", ".")))),
# Enable torch profiler to record shapes if set # Enable torch profiler to record shapes if set
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will # VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will
...@@ -797,6 +802,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -797,6 +802,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")), ("true", "1")),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
"VLLM_ROCM_USE_AITER_FP8BMM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in
("true", "1")),
# use rocm skinny gemms # use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM": "VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
...@@ -979,9 +990,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -979,9 +990,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
# E8M0 is faster on B200 but may reduce accuracy.
"VLLM_USE_DEEP_GEMM_E8M0": "VLLM_USE_DEEP_GEMM_E8M0":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))), lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
# TODO(wentao): unify the two E8M0 flags after verifying the correctness.
# Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs.
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no # JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine # JIT'ing in the hot-path. However, this warmup increases the engine
...@@ -990,6 +1004,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -990,6 +1004,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_SKIP_DEEP_GEMM_WARMUP": "VLLM_SKIP_DEEP_GEMM_WARMUP":
lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))), lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))),
# Whether to use fused grouped_topk used for MoE expert selection.
"VLLM_USE_FUSED_MOE_GROUPED_TOPK":
lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))),
# Allow use of FlashInfer MoE kernels for fused moe ops. # Allow use of FlashInfer MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP8": "VLLM_USE_FLASHINFER_MOE_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))), lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
...@@ -1068,6 +1086,16 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1068,6 +1086,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
# Specifies the thresholds of the communicated tensor sizes under which
# vllm should use flashinfer fused allreduce. The variable should be a
# JSON with the following format:
# { <world size>: <max size in mb> }
# Unspecified world sizes will fallback to
# { 2: 64, 4: 1, <everything else>: 0.5 }
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB":
lambda: json.loads(os.getenv(
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")),
# MoE routing strategy selector. # MoE routing strategy selector.
# See `RoutingSimulator.get_available_strategies()` # for available # See `RoutingSimulator.get_available_strategies()` # for available
# strategies. # strategies.
...@@ -1134,6 +1162,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1134,6 +1162,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_ATTENTION": "VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
# If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN":
lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer. # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
# Otherwise, uses the first available of: flashinfer cutlass GEMM, # Otherwise, uses the first available of: flashinfer cutlass GEMM,
# vllm cutlass GEMM, marlin GEMM. # vllm cutlass GEMM, marlin GEMM.
...@@ -1146,6 +1179,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1146,6 +1179,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_CUDAGRAPH_GC": "VLLM_ENABLE_CUDAGRAPH_GC":
lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))), lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))),
# Disable padding to CUDA graph capture batch sizes.
# TODO(wentao): https://github.com/vllm-project/vllm/issues/23378
# After the issue is fixed, we can remove this flag.
"VLLM_DISABLE_PAD_FOR_CUDAGRAPH":
lambda: bool(int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0"))),
# Used to force set up loopback IP # Used to force set up loopback IP
"VLLM_LOOPBACK_IP": "VLLM_LOOPBACK_IP":
lambda: os.getenv("VLLM_LOOPBACK_IP", ""), lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
...@@ -1179,6 +1218,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1179,6 +1218,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RESPONSES_API_STORE": "VLLM_ENABLE_RESPONSES_API_STORE":
lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
# Allows vllm to find tuned config under customized folder # Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER": "VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
...@@ -1294,6 +1337,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1294,6 +1337,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_GLOBAL_CACHE13": "VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop for moe_fused_gate and moe_align_block_size
"VLLM_USE_LIGHT_OP":
lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
...@@ -1355,10 +1403,12 @@ def compute_hash() -> str: ...@@ -1355,10 +1403,12 @@ def compute_hash() -> str:
"VLLM_USE_AITER_UNIFIED_ATTENTION", "VLLM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ATTENTION_BACKEND", "VLLM_ATTENTION_BACKEND",
"VLLM_USE_FLASHINFER_SAMPLER", "VLLM_USE_FLASHINFER_SAMPLER",
"VLLM_FLASHINFER_FORCE_TENSOR_CORES",
"VLLM_DISABLED_KERNELS", "VLLM_DISABLED_KERNELS",
"VLLM_USE_DEEP_GEMM", "VLLM_USE_DEEP_GEMM",
"VLLM_USE_DEEP_GEMM_E8M0",
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
"VLLM_USE_TRTLLM_FP4_GEMM", "VLLM_USE_TRTLLM_FP4_GEMM",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP8", "VLLM_USE_FLASHINFER_MOE_FP8",
"VLLM_USE_FLASHINFER_MOE_FP4", "VLLM_USE_FLASHINFER_MOE_FP4",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
...@@ -1372,6 +1422,7 @@ def compute_hash() -> str: ...@@ -1372,6 +1422,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_RMSNORM", "VLLM_ROCM_USE_AITER_RMSNORM",
"VLLM_ROCM_USE_AITER_MLA", "VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA", "VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_SKINNY_GEMM", "VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING", "VLLM_ROCM_FP8_PADDING",
"VLLM_ROCM_MOE_PADDING", "VLLM_ROCM_MOE_PADDING",
......
...@@ -101,7 +101,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase): ...@@ -101,7 +101,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
result_handler.start() result_handler.start()
self.worker_monitor.start() self.worker_monitor.start()
# Set up signal handlers to shutdown the executor cleanly # Set up signal handlers to shut down the executor cleanly
# sometimes gc does not work well # sometimes gc does not work well
self.driver_worker = WorkerWrapperBase(self.vllm_config, 0) self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)
......
...@@ -4,11 +4,12 @@ ...@@ -4,11 +4,12 @@
from array import array from array import array
from typing import Any, Type from typing import Any, Type
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any: def encode_hook(obj: Any) -> Any:
"""Custom msgspec enc hook that supports array types. """Custom msgspec enc hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
""" """
...@@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any: ...@@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any:
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
f"Given array has a type code of {obj.typecode}.") f"Given array has a type code of {obj.typecode}.")
return obj.tobytes() return obj.tobytes()
if isinstance(obj, MultiModalKwargs):
return dict(obj)
def decode_hook(type: Type, obj: Any) -> Any: def decode_hook(type: Type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types. """Custom msgspec dec hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
""" """
...@@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any: ...@@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any:
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
deserialized.frombytes(obj) deserialized.frombytes(obj)
return deserialized return deserialized
if type is MultiModalKwargs:
return MultiModalKwargs(obj)
...@@ -10,6 +10,7 @@ import msgspec ...@@ -10,6 +10,7 @@ import msgspec
import vllm.platforms import vllm.platforms
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.distributed import get_pp_group
from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -136,6 +137,11 @@ try: ...@@ -136,6 +137,11 @@ try:
scheduler_output, intermediate_tensors) scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors): if isinstance(output, IntermediateTensors):
output = scheduler_output, output output = scheduler_output, output
elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests
# but may still be finished requests.
assert not output or not output.req_ids
output = scheduler_output, None
return output return output
def override_env_vars(self, vars: Dict[str, str]): def override_env_vars(self, vars: Dict[str, str]):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs, ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
...@@ -18,6 +18,7 @@ target model. ...@@ -18,6 +18,7 @@ target model.
""" """
__all__ = [ __all__ = [
"DataPrompt",
"TextPrompt", "TextPrompt",
"TokensPrompt", "TokensPrompt",
"PromptType", "PromptType",
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs,
MultiModalUUIDDict)
class TextPrompt(TypedDict): class TextPrompt(TypedDict):
...@@ -30,6 +31,15 @@ class TextPrompt(TypedDict): ...@@ -30,6 +31,15 @@ class TextPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them. to pass the mm_processor_kwargs to each of them.
""" """
multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching, and MUST be unique per
multimodal item.
"""
cache_salt: NotRequired[str] cache_salt: NotRequired[str]
""" """
Optional cache salt to be used for prefix caching. Optional cache salt to be used for prefix caching.
...@@ -59,6 +69,14 @@ class TokensPrompt(TypedDict): ...@@ -59,6 +69,14 @@ class TokensPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them. to pass the mm_processor_kwargs to each of them.
""" """
multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching.
"""
cache_salt: NotRequired[str] cache_salt: NotRequired[str]
""" """
Optional cache salt to be used for prefix caching. Optional cache salt to be used for prefix caching.
...@@ -77,6 +95,16 @@ class EmbedsPrompt(TypedDict): ...@@ -77,6 +95,16 @@ class EmbedsPrompt(TypedDict):
""" """
class DataPrompt(TypedDict):
"""Represents generic inputs handled by IO processor plugins."""
data: Any
"""The input data"""
data_format: str
"""The input data format"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
""" """
Set of possible schemas for a single prompt: Set of possible schemas for a single prompt:
...@@ -174,9 +202,6 @@ class TokenInputs(TypedDict): ...@@ -174,9 +202,6 @@ class TokenInputs(TypedDict):
prompt_token_ids: list[int] prompt_token_ids: list[int]
"""The token IDs of the prompt.""" """The token IDs of the prompt."""
token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""
prompt: NotRequired[str] prompt: NotRequired[str]
""" """
The original prompt text corresponding to the token IDs, if available. The original prompt text corresponding to the token IDs, if available.
...@@ -190,7 +215,6 @@ class TokenInputs(TypedDict): ...@@ -190,7 +215,6 @@ class TokenInputs(TypedDict):
def token_inputs( def token_inputs(
prompt_token_ids: list[int], prompt_token_ids: list[int],
token_type_ids: Optional[list[int]] = None,
prompt: Optional[str] = None, prompt: Optional[str] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
) -> TokenInputs: ) -> TokenInputs:
...@@ -200,8 +224,6 @@ def token_inputs( ...@@ -200,8 +224,6 @@ def token_inputs(
if prompt is not None: if prompt is not None:
inputs["prompt"] = prompt inputs["prompt"] = prompt
if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids
if cache_salt is not None: if cache_salt is not None:
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
......
...@@ -11,8 +11,9 @@ from vllm.config import ModelConfig ...@@ -11,8 +11,9 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs) MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
...@@ -32,12 +33,14 @@ class InputPreprocessor: ...@@ -32,12 +33,14 @@ class InputPreprocessor:
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup], tokenizer: Optional[TokenizerGroup],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.model_config = model_config self.model_config = model_config
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup: def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None: if self.tokenizer is None:
...@@ -257,7 +260,9 @@ class InputPreprocessor: ...@@ -257,7 +260,9 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]], mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
...@@ -265,17 +270,22 @@ class InputPreprocessor: ...@@ -265,17 +270,22 @@ class InputPreprocessor:
""" """
tokenizer = self._get_mm_tokenizer(lora_request) tokenizer = self._get_mm_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config, mm_processor = self.mm_registry.create_processor(
tokenizer=tokenizer) self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
return mm_processor.apply(prompt, return mm_processor.apply(
mm_data, prompt,
hf_processor_mm_kwargs=mm_processor_kwargs, mm_data,
tokenization_kwargs=tokenization_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
return_mm_hashes=return_mm_hashes) tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
async def _process_multimodal_async( async def _process_multimodal_async(
self, self,
...@@ -284,7 +294,9 @@ class InputPreprocessor: ...@@ -284,7 +294,9 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]], mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Async version of Async version of
...@@ -292,16 +304,22 @@ class InputPreprocessor: ...@@ -292,16 +304,22 @@ class InputPreprocessor:
""" """
tokenizer = await self._get_mm_tokenizer_async(lora_request) tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config, mm_processor = self.mm_registry.create_processor(
tokenizer=tokenizer) self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
return mm_processor.apply(prompt, return mm_processor.apply(
mm_data, prompt,
hf_processor_mm_kwargs=mm_processor_kwargs, mm_data,
tokenization_kwargs=tokenization_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
return_mm_hashes=return_mm_hashes) tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
def _process_embeds( def _process_embeds(
self, self,
...@@ -333,15 +351,33 @@ class InputPreprocessor: ...@@ -333,15 +351,33 @@ class InputPreprocessor:
) -> EmbedsInputs: ) -> EmbedsInputs:
return self._process_embeds(parsed_content) return self._process_embeds(parsed_content)
def _truncate_inputs(
self,
inputs: list[int],
tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:
if not tokenization_kwargs or "truncation" not in \
tokenization_kwargs or self.tokenizer is None:
return inputs
max_length = tokenization_kwargs["max_length"]
if self.tokenizer.truncation_side == "left":
return inputs[-max_length:]
else:
return inputs[:max_length]
def _process_tokens( def _process_tokens(
self, self,
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"] prompt_token_ids = self._truncate_inputs(
token_type_ids = parsed_content.get("token_type_ids") parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs] inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"): if multi_modal_data := parsed_content.get("multi_modal_data"):
...@@ -351,13 +387,10 @@ class InputPreprocessor: ...@@ -351,13 +387,10 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
else: else:
inputs = token_inputs( inputs = token_inputs(prompt_token_ids=prompt_token_ids)
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
if cache_salt := parsed_content.get("cache_salt"): if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
...@@ -369,10 +402,12 @@ class InputPreprocessor: ...@@ -369,10 +402,12 @@ class InputPreprocessor:
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"] prompt_token_ids = self._truncate_inputs(
token_type_ids = parsed_content.get("token_type_ids") parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs] inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"): if multi_modal_data := parsed_content.get("multi_modal_data"):
...@@ -382,13 +417,10 @@ class InputPreprocessor: ...@@ -382,13 +417,10 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
else: else:
inputs = token_inputs( inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
if cache_salt := parsed_content.get("cache_salt"): if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
...@@ -400,7 +432,9 @@ class InputPreprocessor: ...@@ -400,7 +432,9 @@ class InputPreprocessor:
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
...@@ -412,7 +446,7 @@ class InputPreprocessor: ...@@ -412,7 +446,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
else: else:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
...@@ -435,7 +469,9 @@ class InputPreprocessor: ...@@ -435,7 +469,9 @@ class InputPreprocessor:
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
...@@ -447,7 +483,7 @@ class InputPreprocessor: ...@@ -447,7 +483,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
else: else:
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
...@@ -470,7 +506,9 @@ class InputPreprocessor: ...@@ -470,7 +506,9 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Extract the singleton inputs from a prompt. Extract the singleton inputs from a prompt.
...@@ -479,7 +517,6 @@ class InputPreprocessor: ...@@ -479,7 +517,6 @@ class InputPreprocessor:
* prompt: single encoder or decoder input prompt * prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts * lora_request: this is only valid for decoder prompts
* return_mm_hashes: whether to return multimodal hashes
Returns: Returns:
...@@ -493,21 +530,21 @@ class InputPreprocessor: ...@@ -493,21 +530,21 @@ class InputPreprocessor:
return self._process_tokens( return self._process_tokens(
parsed["content"], parsed["content"],
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return self._process_text( return self._process_text(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return self._process_text( return self._process_text(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
assert_never(parsed) assert_never(parsed)
...@@ -517,7 +554,9 @@ class InputPreprocessor: ...@@ -517,7 +554,9 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Async version of Async version of
...@@ -531,21 +570,21 @@ class InputPreprocessor: ...@@ -531,21 +570,21 @@ class InputPreprocessor:
return await self._process_tokens_async( return await self._process_tokens_async(
parsed["content"], parsed["content"],
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return await self._process_text_async( return await self._process_text_async(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return await self._process_text_async( return await self._process_text_async(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
assert_never(parsed) assert_never(parsed)
...@@ -655,6 +694,9 @@ class InputPreprocessor: ...@@ -655,6 +694,9 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
For encoder/decoder models only: For encoder/decoder models only:
...@@ -696,6 +738,7 @@ class InputPreprocessor: ...@@ -696,6 +738,7 @@ class InputPreprocessor:
encoder_inputs = self._prompt_to_llm_inputs( encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"], prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None decoder_inputs = None
...@@ -711,6 +754,7 @@ class InputPreprocessor: ...@@ -711,6 +754,7 @@ class InputPreprocessor:
inputs = self._prompt_to_llm_inputs( inputs = self._prompt_to_llm_inputs(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -726,6 +770,9 @@ class InputPreprocessor: ...@@ -726,6 +770,9 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
Async version of Async version of
...@@ -738,6 +785,7 @@ class InputPreprocessor: ...@@ -738,6 +785,7 @@ class InputPreprocessor:
encoder_task = self._prompt_to_llm_inputs_async( encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"], prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
...@@ -747,6 +795,7 @@ class InputPreprocessor: ...@@ -747,6 +795,7 @@ class InputPreprocessor:
decoder_task = self._prompt_to_llm_inputs_async( decoder_task = self._prompt_to_llm_inputs_async(
decoder_input, decoder_input,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
encoder_inputs, decoder_inputs = await asyncio.gather( encoder_inputs, decoder_inputs = await asyncio.gather(
...@@ -762,6 +811,7 @@ class InputPreprocessor: ...@@ -762,6 +811,7 @@ class InputPreprocessor:
inputs = await self._prompt_to_llm_inputs_async( inputs = await self._prompt_to_llm_inputs_async(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -788,7 +838,9 @@ class InputPreprocessor: ...@@ -788,7 +838,9 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
For decoder-only models: For decoder-only models:
...@@ -799,7 +851,6 @@ class InputPreprocessor: ...@@ -799,7 +851,6 @@ class InputPreprocessor:
* prompt: input prompt * prompt: input prompt
* lora_request * lora_request
* return_mm_hashes
Returns: Returns:
...@@ -810,7 +861,7 @@ class InputPreprocessor: ...@@ -810,7 +861,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
...@@ -820,7 +871,9 @@ class InputPreprocessor: ...@@ -820,7 +871,9 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
Async version of Async version of
...@@ -830,7 +883,7 @@ class InputPreprocessor: ...@@ -830,7 +883,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
...@@ -840,17 +893,19 @@ class InputPreprocessor: ...@@ -840,17 +893,19 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
"""Preprocess the input prompt.""" """Preprocess the input prompt."""
if self.model_config.is_encoder_decoder: if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder # input prompts to encoder & decoder.
return self._process_encoder_decoder_prompt( return self._process_encoder_decoder_prompt(
prompt, tokenization_kwargs) prompt,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt " raise ValueError("Cannot pass encoder-decoder prompt "
...@@ -861,7 +916,7 @@ class InputPreprocessor: ...@@ -861,7 +916,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
async def preprocess_async( async def preprocess_async(
...@@ -869,19 +924,22 @@ class InputPreprocessor: ...@@ -869,19 +924,22 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
Async version of Async version of
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess]. [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
""" """
if self.model_config.is_encoder_decoder: if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder # input prompts to encoder & decoder.
return await self._process_encoder_decoder_prompt_async(prompt) return await self._process_encoder_decoder_prompt_async(
prompt,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt " raise ValueError("Cannot pass encoder-decoder prompt "
...@@ -892,5 +950,9 @@ class InputPreprocessor: ...@@ -892,5 +950,9 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, mm_hash_overrides=mm_hash_overrides,
) )
def clear_cache(self) -> None:
if self.mm_processor_cache is not None:
self.mm_processor_cache.clear_cache()
...@@ -223,20 +223,26 @@ class InputRegistry: ...@@ -223,20 +223,26 @@ class InputRegistry:
The model is identified by ``model_config``. The model is identified by ``model_config``.
""" """
# Avoid circular import # Avoid circular import
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
return DummyData(seq_data=seq_data) return DummyData(seq_data=seq_data)
cache = processor_only_cache_from_config(model_config, mm_registry)
# Encoder dummy data does not contain multi-modal data # Encoder dummy data does not contain multi-modal data
if is_encoder_data: if is_encoder_data:
enc_data = mm_registry.get_encoder_dummy_data( enc_data = mm_registry.get_encoder_dummy_data(model_config,
model_config, seq_len) seq_len,
cache=cache)
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
return DummyData(seq_data=seq_data) return DummyData(seq_data=seq_data)
dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) dec_data = mm_registry.get_decoder_dummy_data(model_config,
seq_len,
cache=cache)
return DummyData( return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
......
...@@ -48,9 +48,6 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: ...@@ -48,9 +48,6 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# GPTQ/AWQ # GPTQ/AWQ
elif hasattr(base_layer, "qweight"): elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device return base_layer.qweight.device
# marlin
elif hasattr(base_layer, "B"):
return base_layer.B.device
# HQQ marlin # HQQ marlin
elif hasattr(base_layer, "W_q"): elif hasattr(base_layer, "W_q"):
return base_layer.W_q.device return base_layer.W_q.device
...@@ -608,7 +605,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -608,7 +605,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices) """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
packed together (eg. gate_proj + up_proj -> gate_up_proj). packed together (e.g. gate_proj + up_proj -> gate_up_proj).
This means we have 2 LoRAs, each applied to one half of the layer. This means we have 2 LoRAs, each applied to one half of the layer.
......
...@@ -207,6 +207,7 @@ class LoRAModel(AdapterModel): ...@@ -207,6 +207,7 @@ class LoRAModel(AdapterModel):
""" """
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
new_embeddings_tensor_path = os.path.join( new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors") lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir, new_embeddings_bin_file_path = os.path.join(lora_dir,
...@@ -255,9 +256,10 @@ class LoRAModel(AdapterModel): ...@@ -255,9 +256,10 @@ class LoRAModel(AdapterModel):
check_unexpected_modules(f) check_unexpected_modules(f)
for module in f.keys(): # noqa for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module) tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path): elif os.path.isfile(lora_bin_file_path) or os.path.isfile(
# When a bin file is provided, we rely on config to find unexpected lora_pt_file_path):
# modules. # When a bin/pt file is provided, we rely on config to find
# unexpected modules.
unexpected_modules = [] unexpected_modules = []
target_modules = peft_helper.target_modules target_modules = peft_helper.target_modules
if not isinstance(target_modules, list): if not isinstance(target_modules, list):
...@@ -279,7 +281,10 @@ class LoRAModel(AdapterModel): ...@@ -279,7 +281,10 @@ class LoRAModel(AdapterModel):
f" target modules in {expected_lora_modules}" f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}." f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct") f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path, lora_file_path = (lora_bin_file_path
if os.path.isfile(lora_bin_file_path) else
lora_pt_file_path)
tensors = torch.load(lora_file_path,
map_location=device, map_location=device,
weights_only=True) weights_only=True)
else: else:
......
...@@ -10,12 +10,15 @@ import torch.nn.functional as F ...@@ -10,12 +10,15 @@ import torch.nn.functional as F
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import LazyDict from vllm.utils import LazyDict
import vllm.envs as envs import vllm.envs as envs
logger = init_logger(__name__)
@CustomOp.register("fatrelu_and_mul") @CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp): class FatreluAndMul(CustomOp):
...@@ -373,6 +376,112 @@ class ReLUSquaredActivation(CustomOp): ...@@ -373,6 +376,112 @@ class ReLUSquaredActivation(CustomOp):
return self.forward_native(x) return self.forward_native(x)
@CustomOp.register("xielu")
class XIELU(CustomOp):
"""
Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
Otherwise, we emit a single warning and use xIELU Python
"""
def __init__(
self,
alpha_p_init: float = 0.8,
alpha_n_init: float = 0.8,
beta: float = 0.5,
eps: float = -1e-6,
dtype: torch.dtype = torch.bfloat16,
with_vector_loads: bool = False,
):
super().__init__()
self.alpha_p = nn.Parameter(
torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) -
1).unsqueeze(0))
self.alpha_n = nn.Parameter(
torch.log(
torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) -
1).unsqueeze(0))
self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
self.with_vector_loads = with_vector_loads
# Temporary until xIELU CUDA fully implemented
self._beta_scalar = float(self.beta.detach().cpu().float().item())
self._eps_scalar = float(self.eps.detach().cpu().float().item())
self._xielu_cuda_obj = None
try:
import xielu.ops # noqa: F401
self._xielu_cuda_obj = torch.classes.xielu.XIELU()
msg = "Using experimental xIELU CUDA."
try:
from torch._dynamo import allow_in_graph
self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
msg += " Enabled torch._dynamo for xIELU CUDA."
except Exception as err:
msg += (f" Could not enable torch._dynamo for xIELU ({err}) - "
"this may result in slower performance.")
self._xielu_cuda_fn = self._xielu_cuda
logger.warning_once(msg)
except Exception as err:
logger.warning_once(
"CUDA-fused xIELU not available (%s) –"
" falling back to a Python version.\n"
"For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
str(err),
)
def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
alpha_p = nn.functional.softplus(self.alpha_p)
alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
return torch.where(
x > 0,
alpha_p * x * x + self.beta * x,
(torch.expm1(torch.min(x, self.eps)) - x) * alpha_n +
self.beta * x,
)
def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
"""Firewall function to prevent torch.compile from seeing .item()"""
assert self._xielu_cuda_obj is not None, (
"XIELU CUDA object must not be None")
original_shape = x.shape
# CUDA kernel expects 3D tensors, reshape if needed
while x.dim() < 3:
x = x.unsqueeze(0)
if x.dim() > 3:
x = x.view(-1, 1, x.size(-1))
if original_shape != x.shape:
logger.warning_once(
"Warning: xIELU input tensor expects 3 dimensions"
" but got (shape: %s). Reshaping to (shape: %s).",
original_shape,
x.shape,
)
result = self._xielu_cuda_obj.forward(
x,
self.alpha_p,
self.alpha_n,
# Temporary until xIELU CUDA fully implemented ->
# self.{beta,eps}.item()
self._beta_scalar,
self._eps_scalar,
self.with_vector_loads,
)
return result.view(original_shape)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self._xielu_cuda_obj is not None and input.is_cuda:
if not torch._dynamo.is_compiling():
return self._xielu_cuda_fn(input)
else:
logger.warning_once(
"torch._dynamo is compiling, using Python version of xIELU."
)
return self._xielu_python(input)
class ScaledActivation(nn.Module): class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters. """An activation function with post-scale parameters.
...@@ -432,12 +541,25 @@ _ACTIVATION_REGISTRY = LazyDict({ ...@@ -432,12 +541,25 @@ _ACTIVATION_REGISTRY = LazyDict({
lambda: nn.SiLU(), lambda: nn.SiLU(),
"quick_gelu": "quick_gelu":
lambda: QuickGELU(), lambda: QuickGELU(),
"tanh":
lambda: nn.Tanh(),
"sigmoid":
lambda: nn.Sigmoid(),
"xielu":
lambda: XIELU(),
}) })
def get_act_fn(act_fn_name: str) -> nn.Module: def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name.""" """Get an activation function by name."""
act_fn_name = act_fn_name.lower() act_fn_name = act_fn_name.lower()
if act_fn_name.startswith("torch.nn.modules."):
activation_name = act_fn_name.split(".")[-1]
if activation_name == "identity":
return nn.Identity()
act_fn_name = activation_name
if act_fn_name not in _ACTIVATION_REGISTRY: if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError( raise ValueError(
f"Activation function {act_fn_name!r} is not supported.") f"Activation function {act_fn_name!r} is not supported.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class AttentionLayerBase(ABC):
"""
Base class for attention-like layers (Attention, Mamba, etc.)
that support the v1 engine.
This provides a common interface for getting attention backends
from different layer types.
"""
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this layer."""
pass
...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
is_blackwell_deep_gemm_e8m0_used) is_deep_gemm_e8m0_used)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm( ...@@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm(
# number of valid tokens for this expert # number of valid tokens for this expert
n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
cols = tl.arange(0, BLOCK) cols = tl.arange(0, BLOCK).to(tl.int64)
cols = cols.to(tl.int64) mask = cols < BLOCK
mask_h = cols < BLOCK
base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
base_gate_offset = base_input_offset + cols * stride_i_h
base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h +
cols * stride_yq_h)
base_ys_offset = e * stride_ys_e + g * stride_ys_g
for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
base_i_offset = (e * stride_i_e + t * stride_i_t + gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t,
g * GROUP_SIZE * stride_i_h) mask=mask,
base_yq_offset = (e * stride_yq_e + t * stride_yq_t + other=0.0).to(tl.float32)
g * GROUP_SIZE * stride_yq_h) up = tl.load(input_ptr + base_up_offset + t * stride_i_t,
base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g
mask = mask_h
x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
mask=mask,
other=0.0).to(tl.float32)
y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
cols * stride_i_h,
mask=mask, mask=mask,
other=0.0).to(tl.float32) other=0.0)
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
y = gate * up
x = x * (1.0 / (1.0 + tl.exp(-x))) y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
y = x * y2 if use_ue8m0:
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
scale_raw = _absmax / fp8_max
y_s = tl.math.exp2(tl.ceil(
tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask) tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
tl.store(y_s_ptr + base_ys_offset, y_s) tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
def silu_mul_fp8_quant_deep_gemm( def silu_mul_fp8_quant_deep_gemm(
y: torch.Tensor, # (E, T, 2*H) float32 y: torch.Tensor, # (E, T, 2*H)
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
group_size: int = 128, group_size: int = 128,
eps: float = 1e-10, eps: float = 1e-10,
): ) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8. silu-activated, multiplied by the second half, then quantized into FP8.
Returns `(y_q, y_s)` where Returns `(y_q, y_s)` where
* `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`. * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)` * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
""" """
assert y.ndim == 3, "y must be (E, T, 2*H)" assert y.ndim == 3, "y must be (E, T, 2*H)"
E, T, H2 = y.shape E, T, H2 = y.shape
...@@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm( ...@@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm(
stride_cnt_e = tokens_per_expert.stride()[0] stride_cnt_e = tokens_per_expert.stride()[0]
# static grid over experts and H-groups. # Static grid over experts and H-groups.
# A loop inside the kernel handles the token dim # A loop inside the kernel handles the token dim
grid = (E * G, ) grid = (E * G, )
...@@ -176,9 +174,9 @@ def silu_mul_fp8_quant_deep_gemm( ...@@ -176,9 +174,9 @@ def silu_mul_fp8_quant_deep_gemm(
eps, eps,
fp8_min, fp8_min,
fp8_max, fp8_max,
is_blackwell_deep_gemm_e8m0_used(), is_deep_gemm_e8m0_used(),
BLOCK=group_size, BLOCK=group_size,
NUM_STAGES=8, NUM_STAGES=4,
num_warps=1, num_warps=1,
) )
......
...@@ -194,12 +194,6 @@ class FusedMoEParallelConfig: ...@@ -194,12 +194,6 @@ class FusedMoEParallelConfig:
return (self.use_all2all_kernels return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@property
def use_flashinfer_cutlass_kernels(self):
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod @staticmethod
def make(tp_size_: int, dp_size_: int, def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
...@@ -408,7 +402,14 @@ class FusedMoEConfig: ...@@ -408,7 +402,14 @@ class FusedMoEConfig:
@property @property
def use_flashinfer_cutlass_kernels(self): def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_flashinfer_cutlass_kernels """
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
"""
return (self.quant_config is not None
and self.quant_config.quant_dtype == "nvfp4"
and envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod @staticmethod
def make( def make(
...@@ -454,6 +455,12 @@ class FusedMoEConfig: ...@@ -454,6 +455,12 @@ class FusedMoEConfig:
if quant_dtype is None and isinstance(quant_config, Fp8Config): if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn quant_dtype = torch.float8_e4m3fn
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Config)
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
quant_dtype = "mxfp8"
from vllm.model_executor.layers.quantization.modelopt import ( from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config) ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config, if quant_dtype is None and isinstance(quant_config,
......
{
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
}
}
{
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
}
}
...@@ -3,10 +3,115 @@ ...@@ -3,10 +3,115 @@
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
from torch.nn import functional as F
from vllm import envs from vllm import envs
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
gating_output = gating_output.float()
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights, topk_ids.to(torch.int32)
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
return grouped_topk(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None:
assert scoring_func == "softmax"
topk_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids.to(torch.int32)
else:
return custom_routing_function(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
class IPEXFusedMOE: class IPEXFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None: def __init__(self, layer: torch.nn.Module) -> None:
...@@ -31,12 +136,15 @@ class IPEXFusedMOE: ...@@ -31,12 +136,15 @@ class IPEXFusedMOE:
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported." assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input assert not apply_router_weight_on_input
assert routed_scaling_factor == 1.0, \
f"routed_scaling_factor {routed_scaling_factor} is not supported."
return layer.ipex_fusion( return layer.ipex_fusion(
x, x,
use_grouped_topk, use_grouped_topk,
...@@ -56,113 +164,6 @@ class SGLFusedMOE: ...@@ -56,113 +164,6 @@ class SGLFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None: def __init__(self, layer: torch.nn.Module) -> None:
pass pass
@staticmethod
def _grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
gating_output = gating_output.float()
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use
# biased scores for expert selection but original scores for
# routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores,
k=topk_group,
dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token,
-1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1,
keepdim=True)
return topk_weights, topk_ids.to(torch.int32)
@staticmethod
def _select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = SGLFusedMOE._grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None:
assert scoring_func == "softmax"
topk_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
topk_ids = topk_ids.to(torch.int32)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
return topk_weights, topk_ids
def __call__( def __call__(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -177,13 +178,14 @@ class SGLFusedMOE: ...@@ -177,13 +178,14 @@ class SGLFusedMOE:
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported." assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input assert not apply_router_weight_on_input
topk_weights, topk_ids = SGLFusedMOE._select_experts( topk_weights, topk_ids = select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
...@@ -193,6 +195,7 @@ class SGLFusedMOE: ...@@ -193,6 +195,7 @@ class SGLFusedMOE:
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
) )
...@@ -213,3 +216,82 @@ class SGLFusedMOE: ...@@ -213,3 +216,82 @@ class SGLFusedMOE:
True, True,
) )
return x return x
class CPUFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
pass
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
# Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
len_experts = global_num_experts
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
layer_w13_weight = layer.w13_weight[i]
layer_w2_weight = layer.w2_weight[i]
gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
gate_up = silu_and_mul(gate_up)
expert_out = F.linear(gate_up, layer_w2_weight)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs,
dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (new_x.view(
*topk_ids.shape, -1).type(topk_weights.dtype).mul_(
topk_weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
return final_out
...@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_unpermute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP) TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_fp8_quantize,
_resize_cache) _resize_cache)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
...@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8( ...@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
...@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8( ...@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
use_batched_format: bool, use_batched_format: bool,
topk_weights: Optional[torch.Tensor],
): ):
a1q = hidden_states a1q = hidden_states
...@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8( ...@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
topk = local_topk_ids.size(1) topk = local_topk_ids.size(1)
local_E = w1.size(0) local_E = w1.size(0)
if use_batched_format:
mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2))
act_out = _resize_cache(workspace2, (local_E * padded_M, N))
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(local_E * padded_M, N))
mm2_out = _resize_cache(workspace2, (local_E * padded_M, K))
else:
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
(M * topk, K))
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
act_out = _resize_cache(workspace2, (M * topk, N))
# original workspace are based on input hidden_states dtype (bf16)
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(M * topk, N))
mm2_out = _resize_cache(workspace2, (M * topk, K))
if use_batched_format: if use_batched_format:
assert expert_num_tokens is not None assert expert_num_tokens is not None
...@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8( ...@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
w2_scale = w2_scale.reshape(w2_scale.size(0), -1) w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
a1q = a1q.reshape(-1, a1q.size(2)) a1q = a1q.reshape(-1, a1q.size(2))
a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous() a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()
# c3x get_group_gemm_starts expects int64 to avoid overflow
# during offset calculations
expert_offsets = expert_offsets.to(torch.int64)
else: else:
expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((global_num_experts, 3), problem_sizes1 = torch.empty((global_num_experts, 3),
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
...@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8( ...@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
# With expert_map each Rank processes only a subset of experts. As num_expert = global_num_experts if expert_map is None \
# a result not all of a_map and c2 tensors are filled. We fill it else expert_map.size(0)
# zeros for correctness. # permuted a1q reuses workspace2
if expert_map is not None: a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
a_map = torch.zeros((local_topk_ids.numel()), a1q,
dtype=torch.int32, a1q_scale,
device=device) topk_ids,
else: num_expert,
a_map = torch.empty((local_topk_ids.numel()), local_E,
dtype=torch.int32, expert_map,
device=device) permuted_hidden_states=a1q_perm)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
expert_offsets = expert_offsets[:-1] expert_offsets = expert_offsets[:-1]
ab_strides1 = torch.full((w1.size(0), ), ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1,
K, problem_sizes2,
device=device, global_num_experts, N, K)
dtype=torch.int64)
c_strides1 = torch.full((w1.size(0), ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.size(0), ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
else:
c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))
if not per_act_token and (expert_map is not None or use_batched_format): if not per_act_token and (expert_map is not None or use_batched_format):
# this is necessary to avoid imprecise scale calculation caused by # this is necessary to avoid imprecise scale calculation caused by
# random data in the unused workspace. The workspace is unused when # random data in the unused workspace. The workspace is unused when
# this rank handles only partial tokens, or when it is batched . # this rank handles only partial tokens, or when it is batched .
c1.fill_(0) mm1_out.fill_(0)
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets, ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1, problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch) per_act_token, per_out_ch)
activation_callable(c2, c1) activation_callable(act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant( a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token) act_out,
a2_scale,
use_per_token_if_dynamic=per_act_token,
output=quant_out)
if expert_map is not None: if expert_map is not None:
c3.fill_(0) mm2_out.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets, ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets,
problem_sizes2, ab_strides2, ab_strides2, c_strides2, problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch) per_act_token, per_out_ch)
if use_batched_format: if use_batched_format:
output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True) output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True)
else: else:
# We can't do this inplace because output may point to the same tensor # for non-chunking mode the output is resized from workspace13
# as c3. # so we need to make sure mm2_out uses workspace2.
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) moe_unpermute(out=output,
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm)
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
...@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype],
per_act_token_quant: bool, per_act_token_quant: bool,
per_out_ch_quant: bool, per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
): ):
super().__init__( super().__init__(
...@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
block_shape=block_shape, block_shape=block_shape,
)) ))
self.out_dtype = out_dtype self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. # Let PrepareAndFinalize::finalize() decide the impl.
...@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
run_cutlass_moe_fp8( run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable, output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens, a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype, self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant, self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format) use_batched_format, topk_weights)
class CutlassExpertsFp8(CutlassExpertsFp8Base): class CutlassExpertsFp8(CutlassExpertsFp8Base):
...@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
out_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype],
per_act_token_quant: bool, per_act_token_quant: bool,
per_out_ch_quant: bool, per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
): ):
super().__init__( super().__init__(
out_dtype, out_dtype,
per_act_token_quant, per_act_token_quant,
per_out_ch_quant, per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape, block_shape,
) )
...@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# topk weights and reduction are fused in moe_unpermute cuda kernel
return TopKWeightAndReduceNoOP()
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M * topk, max(N, K)) workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, N // 2) workspace2 = (M * topk, max(N // 2, K))
output = (M * topk, K) output = (M, K)
return (workspace1, workspace2, output, return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype) self.out_dtype if self.out_dtype is not None else a.dtype)
...@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ...@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
out_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype],
per_act_token_quant: bool, per_act_token_quant: bool,
per_out_ch_quant: bool, per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
): ):
super().__init__( super().__init__(
out_dtype, out_dtype,
per_act_token_quant, per_act_token_quant,
per_out_ch_quant, per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape, block_shape,
) )
assert max_experts_per_worker > 0 assert max_experts_per_worker > 0
...@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ...@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
assert num_dp is not None assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp, workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K)) max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2)) workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
max(N // 2, K))
output = (self.max_experts_per_worker, padded_M, K) output = (self.max_experts_per_worker, padded_M, K)
return (workspace1, workspace2, output, return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype) self.out_dtype if self.out_dtype is not None else a.dtype)
...@@ -392,6 +416,10 @@ def cutlass_moe_fp8( ...@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None, per_act_token: Optional[bool] = None,
activation: str = "silu", activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
...@@ -419,6 +447,17 @@ def cutlass_moe_fp8( ...@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N] Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K] Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M] Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
...@@ -450,6 +489,10 @@ def cutlass_moe_fp8( ...@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
out_dtype=a.dtype, out_dtype=a.dtype,
per_act_token_quant=per_act_token, per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch, per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
), ),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment