Commit a99300bd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev

parents cc3e01c7 5438967f
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import json
import os
import sys
import tempfile
......@@ -42,7 +43,6 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
......@@ -99,6 +99,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
......@@ -131,7 +132,9 @@ if TYPE_CHECKING:
VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
......@@ -159,9 +162,12 @@ if TYPE_CHECKING:
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
# add envs
VLLM_OPTEST_URLS_PORT: Optional[int] = None
......@@ -188,6 +194,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False
def get_default_cache_root():
......@@ -491,11 +498,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,
# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
"VLLM_FLASHINFER_FORCE_TENSOR_CORES":
lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
......@@ -693,11 +695,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR":
lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None),
# Enables torch profiler if set. Path to the directory where torch profiler
# traces are saved. Note that it must be an absolute path.
# Enables torch profiler if set.
# Both AsyncLLM's CPU traces as well as workers'
# traces (CPU & GPU) will be saved under this directory.
# Note that it must be an absolute path.
"VLLM_TORCH_PROFILER_DIR":
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
.path.abspath(os.path.expanduser(os.getenv(
"VLLM_TORCH_PROFILER_DIR", ".")))),
# Enable torch profiler to record shapes if set
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will
......@@ -797,6 +802,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
"VLLM_ROCM_USE_AITER_FP8BMM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in
("true", "1")),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
......@@ -979,9 +990,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
# E8M0 is faster on B200 but may reduce accuracy.
"VLLM_USE_DEEP_GEMM_E8M0":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
# TODO(wentao): unify the two E8M0 flags after verifying the correctness.
# Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs.
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
......@@ -990,6 +1004,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_SKIP_DEEP_GEMM_WARMUP":
lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))),
# Whether to use fused grouped_topk used for MoE expert selection.
"VLLM_USE_FUSED_MOE_GROUPED_TOPK":
lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))),
# Allow use of FlashInfer MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
......@@ -1068,6 +1086,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
# Specifies the thresholds of the communicated tensor sizes under which
# vllm should use flashinfer fused allreduce. The variable should be a
# JSON with the following format:
# { <world size>: <max size in mb> }
# Unspecified world sizes will fallback to
# { 2: 64, 4: 1, <everything else>: 0.5 }
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB":
lambda: json.loads(os.getenv(
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")),
# MoE routing strategy selector.
# See `RoutingSimulator.get_available_strategies()` # for available
# strategies.
......@@ -1134,6 +1162,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
# If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN":
lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
# vllm cutlass GEMM, marlin GEMM.
......@@ -1146,6 +1179,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_CUDAGRAPH_GC":
lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))),
# Disable padding to CUDA graph capture batch sizes.
# TODO(wentao): https://github.com/vllm-project/vllm/issues/23378
# After the issue is fixed, we can remove this flag.
"VLLM_DISABLE_PAD_FOR_CUDAGRAPH":
lambda: bool(int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0"))),
# Used to force set up loopback IP
"VLLM_LOOPBACK_IP":
lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
......@@ -1179,6 +1218,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RESPONSES_API_STORE":
lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
# Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
......@@ -1294,6 +1337,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")),
# vLLM will use lightop for moe_fused_gate and moe_align_block_size
"VLLM_USE_LIGHT_OP":
lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......@@ -1355,10 +1403,12 @@ def compute_hash() -> str:
"VLLM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ATTENTION_BACKEND",
"VLLM_USE_FLASHINFER_SAMPLER",
"VLLM_FLASHINFER_FORCE_TENSOR_CORES",
"VLLM_DISABLED_KERNELS",
"VLLM_USE_DEEP_GEMM",
"VLLM_USE_DEEP_GEMM_E8M0",
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
"VLLM_USE_TRTLLM_FP4_GEMM",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP8",
"VLLM_USE_FLASHINFER_MOE_FP4",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
......@@ -1372,6 +1422,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_RMSNORM",
"VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING",
"VLLM_ROCM_MOE_PADDING",
......
......@@ -101,7 +101,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
result_handler.start()
self.worker_monitor.start()
# Set up signal handlers to shutdown the executor cleanly
# Set up signal handlers to shut down the executor cleanly
# sometimes gc does not work well
self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)
......
......@@ -4,11 +4,12 @@
from array import array
from typing import Any, Type
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any:
"""Custom msgspec enc hook that supports array types.
"""Custom msgspec enc hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
......@@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any:
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
f"Given array has a type code of {obj.typecode}.")
return obj.tobytes()
if isinstance(obj, MultiModalKwargs):
return dict(obj)
def decode_hook(type: Type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types.
"""Custom msgspec dec hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
......@@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any:
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
deserialized.frombytes(obj)
return deserialized
if type is MultiModalKwargs:
return MultiModalKwargs(obj)
......@@ -10,6 +10,7 @@ import msgspec
import vllm.platforms
from vllm.config import ParallelConfig
from vllm.distributed import get_pp_group
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -136,6 +137,11 @@ try:
scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests
# but may still be finished requests.
assert not output or not output.req_ids
output = scheduler_output, None
return output
def override_env_vars(self, vars: Dict[str, str]):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
......@@ -18,6 +18,7 @@ target model.
"""
__all__ = [
"DataPrompt",
"TextPrompt",
"TokensPrompt",
"PromptType",
......
......@@ -7,7 +7,8 @@ import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs,
MultiModalUUIDDict)
class TextPrompt(TypedDict):
......@@ -30,6 +31,15 @@ class TextPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them.
"""
multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching, and MUST be unique per
multimodal item.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
......@@ -59,6 +69,14 @@ class TokensPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them.
"""
multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
......@@ -77,6 +95,16 @@ class EmbedsPrompt(TypedDict):
"""
class DataPrompt(TypedDict):
"""Represents generic inputs handled by IO processor plugins."""
data: Any
"""The input data"""
data_format: str
"""The input data format"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:
......@@ -174,9 +202,6 @@ class TokenInputs(TypedDict):
prompt_token_ids: list[int]
"""The token IDs of the prompt."""
token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""
prompt: NotRequired[str]
"""
The original prompt text corresponding to the token IDs, if available.
......@@ -190,7 +215,6 @@ class TokenInputs(TypedDict):
def token_inputs(
prompt_token_ids: list[int],
token_type_ids: Optional[list[int]] = None,
prompt: Optional[str] = None,
cache_salt: Optional[str] = None,
) -> TokenInputs:
......@@ -200,8 +224,6 @@ def token_inputs(
if prompt is not None:
inputs["prompt"] = prompt
if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
......
......@@ -11,8 +11,9 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
......@@ -32,12 +33,14 @@ class InputPreprocessor:
model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None:
super().__init__()
self.model_config = model_config
self.tokenizer = tokenizer
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
......@@ -257,7 +260,9 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalInputs:
"""
Apply the model's multi-modal processor to a multi-modal prompt,
......@@ -265,17 +270,22 @@ class InputPreprocessor:
"""
tokenizer = self._get_mm_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return mm_processor.apply(prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes)
return mm_processor.apply(
prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
async def _process_multimodal_async(
self,
......@@ -284,7 +294,9 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalInputs:
"""
Async version of
......@@ -292,16 +304,22 @@ class InputPreprocessor:
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return mm_processor.apply(prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes)
return mm_processor.apply(
prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
def _process_embeds(
self,
......@@ -333,15 +351,33 @@ class InputPreprocessor:
) -> EmbedsInputs:
return self._process_embeds(parsed_content)
def _truncate_inputs(
self,
inputs: list[int],
tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:
if not tokenization_kwargs or "truncation" not in \
tokenization_kwargs or self.tokenizer is None:
return inputs
max_length = tokenization_kwargs["max_length"]
if self.tokenizer.truncation_side == "left":
return inputs[-max_length:]
else:
return inputs[:max_length]
def _process_tokens(
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
......@@ -351,13 +387,10 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
else:
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
inputs = token_inputs(prompt_token_ids=prompt_token_ids)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
......@@ -369,10 +402,12 @@ class InputPreprocessor:
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
......@@ -382,13 +417,10 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
else:
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
......@@ -400,7 +432,9 @@ class InputPreprocessor:
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
......@@ -412,7 +446,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
else:
prompt_token_ids = self._tokenize_prompt(
......@@ -435,7 +469,9 @@ class InputPreprocessor:
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
......@@ -447,7 +483,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
......@@ -470,7 +506,9 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> SingletonInputs:
"""
Extract the singleton inputs from a prompt.
......@@ -479,7 +517,6 @@ class InputPreprocessor:
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
* return_mm_hashes: whether to return multimodal hashes
Returns:
......@@ -493,21 +530,21 @@ class InputPreprocessor:
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
assert_never(parsed)
......@@ -517,7 +554,9 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> SingletonInputs:
"""
Async version of
......@@ -531,21 +570,21 @@ class InputPreprocessor:
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
assert_never(parsed)
......@@ -655,6 +694,9 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
......@@ -696,6 +738,7 @@ class InputPreprocessor:
encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None
......@@ -711,6 +754,7 @@ class InputPreprocessor:
inputs = self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
......@@ -726,6 +770,9 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs:
"""
Async version of
......@@ -738,6 +785,7 @@ class InputPreprocessor:
encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
......@@ -747,6 +795,7 @@ class InputPreprocessor:
decoder_task = self._prompt_to_llm_inputs_async(
decoder_input,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
encoder_inputs, decoder_inputs = await asyncio.gather(
......@@ -762,6 +811,7 @@ class InputPreprocessor:
inputs = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
......@@ -788,7 +838,9 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs:
"""
For decoder-only models:
......@@ -799,7 +851,6 @@ class InputPreprocessor:
* prompt: input prompt
* lora_request
* return_mm_hashes
Returns:
......@@ -810,7 +861,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
return self._build_decoder_only_llm_inputs(prompt_comps)
......@@ -820,7 +871,9 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs:
"""
Async version of
......@@ -830,7 +883,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
return self._build_decoder_only_llm_inputs(prompt_comps)
......@@ -840,17 +893,19 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
# input prompts to encoder & decoder.
return self._process_encoder_decoder_prompt(
prompt, tokenization_kwargs)
prompt,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
......@@ -861,7 +916,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
async def preprocess_async(
......@@ -869,19 +924,22 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> ProcessorInputs:
"""
Async version of
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(prompt)
# input prompts to encoder & decoder.
return await self._process_encoder_decoder_prompt_async(
prompt,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
......@@ -892,5 +950,9 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
mm_hash_overrides=mm_hash_overrides,
)
def clear_cache(self) -> None:
if self.mm_processor_cache is not None:
self.mm_processor_cache.clear_cache()
......@@ -223,20 +223,26 @@ class InputRegistry:
The model is identified by ``model_config``.
"""
# Avoid circular import
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.sequence import SequenceData
if not model_config.is_multimodal_model:
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
return DummyData(seq_data=seq_data)
cache = processor_only_cache_from_config(model_config, mm_registry)
# Encoder dummy data does not contain multi-modal data
if is_encoder_data:
enc_data = mm_registry.get_encoder_dummy_data(
model_config, seq_len)
enc_data = mm_registry.get_encoder_dummy_data(model_config,
seq_len,
cache=cache)
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
return DummyData(seq_data=seq_data)
dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len)
dec_data = mm_registry.get_decoder_dummy_data(model_config,
seq_len,
cache=cache)
return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
......
......@@ -48,9 +48,6 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# GPTQ/AWQ
elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device
# marlin
elif hasattr(base_layer, "B"):
return base_layer.B.device
# HQQ marlin
elif hasattr(base_layer, "W_q"):
return base_layer.W_q.device
......@@ -608,7 +605,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
packed together (eg. gate_proj + up_proj -> gate_up_proj).
packed together (e.g. gate_proj + up_proj -> gate_up_proj).
This means we have 2 LoRAs, each applied to one half of the layer.
......
......@@ -207,6 +207,7 @@ class LoRAModel(AdapterModel):
"""
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
......@@ -255,9 +256,10 @@ class LoRAModel(AdapterModel):
check_unexpected_modules(f)
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path):
# When a bin file is provided, we rely on config to find unexpected
# modules.
elif os.path.isfile(lora_bin_file_path) or os.path.isfile(
lora_pt_file_path):
# When a bin/pt file is provided, we rely on config to find
# unexpected modules.
unexpected_modules = []
target_modules = peft_helper.target_modules
if not isinstance(target_modules, list):
......@@ -279,7 +281,10 @@ class LoRAModel(AdapterModel):
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path,
lora_file_path = (lora_bin_file_path
if os.path.isfile(lora_bin_file_path) else
lora_pt_file_path)
tensors = torch.load(lora_file_path,
map_location=device,
weights_only=True)
else:
......
......@@ -10,12 +10,15 @@ import torch.nn.functional as F
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict
import vllm.envs as envs
logger = init_logger(__name__)
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
......@@ -373,6 +376,112 @@ class ReLUSquaredActivation(CustomOp):
return self.forward_native(x)
@CustomOp.register("xielu")
class XIELU(CustomOp):
"""
Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
Otherwise, we emit a single warning and use xIELU Python
"""
def __init__(
self,
alpha_p_init: float = 0.8,
alpha_n_init: float = 0.8,
beta: float = 0.5,
eps: float = -1e-6,
dtype: torch.dtype = torch.bfloat16,
with_vector_loads: bool = False,
):
super().__init__()
self.alpha_p = nn.Parameter(
torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) -
1).unsqueeze(0))
self.alpha_n = nn.Parameter(
torch.log(
torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) -
1).unsqueeze(0))
self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
self.with_vector_loads = with_vector_loads
# Temporary until xIELU CUDA fully implemented
self._beta_scalar = float(self.beta.detach().cpu().float().item())
self._eps_scalar = float(self.eps.detach().cpu().float().item())
self._xielu_cuda_obj = None
try:
import xielu.ops # noqa: F401
self._xielu_cuda_obj = torch.classes.xielu.XIELU()
msg = "Using experimental xIELU CUDA."
try:
from torch._dynamo import allow_in_graph
self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
msg += " Enabled torch._dynamo for xIELU CUDA."
except Exception as err:
msg += (f" Could not enable torch._dynamo for xIELU ({err}) - "
"this may result in slower performance.")
self._xielu_cuda_fn = self._xielu_cuda
logger.warning_once(msg)
except Exception as err:
logger.warning_once(
"CUDA-fused xIELU not available (%s) –"
" falling back to a Python version.\n"
"For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
str(err),
)
def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
alpha_p = nn.functional.softplus(self.alpha_p)
alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
return torch.where(
x > 0,
alpha_p * x * x + self.beta * x,
(torch.expm1(torch.min(x, self.eps)) - x) * alpha_n +
self.beta * x,
)
def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
"""Firewall function to prevent torch.compile from seeing .item()"""
assert self._xielu_cuda_obj is not None, (
"XIELU CUDA object must not be None")
original_shape = x.shape
# CUDA kernel expects 3D tensors, reshape if needed
while x.dim() < 3:
x = x.unsqueeze(0)
if x.dim() > 3:
x = x.view(-1, 1, x.size(-1))
if original_shape != x.shape:
logger.warning_once(
"Warning: xIELU input tensor expects 3 dimensions"
" but got (shape: %s). Reshaping to (shape: %s).",
original_shape,
x.shape,
)
result = self._xielu_cuda_obj.forward(
x,
self.alpha_p,
self.alpha_n,
# Temporary until xIELU CUDA fully implemented ->
# self.{beta,eps}.item()
self._beta_scalar,
self._eps_scalar,
self.with_vector_loads,
)
return result.view(original_shape)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self._xielu_cuda_obj is not None and input.is_cuda:
if not torch._dynamo.is_compiling():
return self._xielu_cuda_fn(input)
else:
logger.warning_once(
"torch._dynamo is compiling, using Python version of xIELU."
)
return self._xielu_python(input)
class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
......@@ -432,12 +541,25 @@ _ACTIVATION_REGISTRY = LazyDict({
lambda: nn.SiLU(),
"quick_gelu":
lambda: QuickGELU(),
"tanh":
lambda: nn.Tanh(),
"sigmoid":
lambda: nn.Sigmoid(),
"xielu":
lambda: XIELU(),
})
def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name.startswith("torch.nn.modules."):
activation_name = act_fn_name.split(".")[-1]
if activation_name == "identity":
return nn.Identity()
act_fn_name = activation_name
if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class AttentionLayerBase(ABC):
"""
Base class for attention-like layers (Attention, Mamba, etc.)
that support the v1 engine.
This provides a common interface for getting attention backends
from different layer types.
"""
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this layer."""
pass
......@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
is_blackwell_deep_gemm_e8m0_used)
is_deep_gemm_e8m0_used)
logger = init_logger(__name__)
......@@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm(
# number of valid tokens for this expert
n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
cols = tl.arange(0, BLOCK)
cols = cols.to(tl.int64)
mask_h = cols < BLOCK
cols = tl.arange(0, BLOCK).to(tl.int64)
mask = cols < BLOCK
base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
base_gate_offset = base_input_offset + cols * stride_i_h
base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h +
cols * stride_yq_h)
base_ys_offset = e * stride_ys_e + g * stride_ys_g
for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
base_i_offset = (e * stride_i_e + t * stride_i_t +
g * GROUP_SIZE * stride_i_h)
base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
g * GROUP_SIZE * stride_yq_h)
base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g
mask = mask_h
x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
mask=mask,
other=0.0).to(tl.float32)
y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
cols * stride_i_h,
gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t,
mask=mask,
other=0.0).to(tl.float32)
up = tl.load(input_ptr + base_up_offset + t * stride_i_t,
mask=mask,
other=0.0).to(tl.float32)
other=0.0)
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
y = gate * up
x = x * (1.0 / (1.0 + tl.exp(-x)))
y = x * y2
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
if use_ue8m0:
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
scale_raw = _absmax / fp8_max
y_s = tl.math.exp2(tl.ceil(
tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
tl.store(y_s_ptr + base_ys_offset, y_s)
tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
def silu_mul_fp8_quant_deep_gemm(
y: torch.Tensor, # (E, T, 2*H) float32
y: torch.Tensor, # (E, T, 2*H)
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
group_size: int = 128,
eps: float = 1e-10,
):
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8.
Returns `(y_q, y_s)` where
* `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
* `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
"""
assert y.ndim == 3, "y must be (E, T, 2*H)"
E, T, H2 = y.shape
......@@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm(
stride_cnt_e = tokens_per_expert.stride()[0]
# static grid over experts and H-groups.
# Static grid over experts and H-groups.
# A loop inside the kernel handles the token dim
grid = (E * G, )
......@@ -176,9 +174,9 @@ def silu_mul_fp8_quant_deep_gemm(
eps,
fp8_min,
fp8_max,
is_blackwell_deep_gemm_e8m0_used(),
is_deep_gemm_e8m0_used(),
BLOCK=group_size,
NUM_STAGES=8,
NUM_STAGES=4,
num_warps=1,
)
......
......@@ -194,12 +194,6 @@ class FusedMoEParallelConfig:
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@property
def use_flashinfer_cutlass_kernels(self):
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
......@@ -408,7 +402,14 @@ class FusedMoEConfig:
@property
def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_flashinfer_cutlass_kernels
"""
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
"""
return (self.quant_config is not None
and self.quant_config.quant_dtype == "nvfp4"
and envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod
def make(
......@@ -454,6 +455,12 @@ class FusedMoEConfig:
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Config)
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
quant_dtype = "mxfp8"
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
......
{
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
}
}
{
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
}
}
......@@ -3,10 +3,115 @@
from typing import Callable, Optional
import torch
from torch.nn import functional as F
from vllm import envs
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
gating_output = gating_output.float()
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights, topk_ids.to(torch.int32)
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
return grouped_topk(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None:
assert scoring_func == "softmax"
topk_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids.to(torch.int32)
else:
return custom_routing_function(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
class IPEXFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
......@@ -31,12 +136,15 @@ class IPEXFusedMOE:
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
assert routed_scaling_factor == 1.0, \
f"routed_scaling_factor {routed_scaling_factor} is not supported."
return layer.ipex_fusion(
x,
use_grouped_topk,
......@@ -56,113 +164,6 @@ class SGLFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
pass
@staticmethod
def _grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
gating_output = gating_output.float()
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use
# biased scores for expert selection but original scores for
# routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores,
k=topk_group,
dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token,
-1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1,
keepdim=True)
return topk_weights, topk_ids.to(torch.int32)
@staticmethod
def _select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = SGLFusedMOE._grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None:
assert scoring_func == "softmax"
topk_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
topk_ids = topk_ids.to(torch.int32)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
return topk_weights, topk_ids
def __call__(
self,
layer: torch.nn.Module,
......@@ -177,13 +178,14 @@ class SGLFusedMOE:
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = SGLFusedMOE._select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -193,6 +195,7 @@ class SGLFusedMOE:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
......@@ -213,3 +216,82 @@ class SGLFusedMOE:
True,
)
return x
class CPUFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
pass
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
# Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
len_experts = global_num_experts
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
layer_w13_weight = layer.w13_weight[i]
layer_w2_weight = layer.w2_weight[i]
gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
gate_up = silu_and_mul(gate_up)
expert_out = F.linear(gate_up, layer_w2_weight)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs,
dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (new_x.view(
*topk_ids.shape, -1).type(topk_weights.dtype).mul_(
topk_weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
return final_out
......@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_unpermute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
from vllm.scalar_type import scalar_types
......@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
......@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
per_act_token: bool,
per_out_ch: bool,
use_batched_format: bool,
topk_weights: Optional[torch.Tensor],
):
a1q = hidden_states
......@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
topk = local_topk_ids.size(1)
local_E = w1.size(0)
if use_batched_format:
mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2))
act_out = _resize_cache(workspace2, (local_E * padded_M, N))
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(local_E * padded_M, N))
mm2_out = _resize_cache(workspace2, (local_E * padded_M, K))
else:
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
(M * topk, K))
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
act_out = _resize_cache(workspace2, (M * topk, N))
# original workspace are based on input hidden_states dtype (bf16)
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(M * topk, N))
mm2_out = _resize_cache(workspace2, (M * topk, K))
if use_batched_format:
assert expert_num_tokens is not None
......@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
a1q = a1q.reshape(-1, a1q.size(2))
a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()
# c3x get_group_gemm_starts expects int64 to avoid overflow
# during offset calculations
expert_offsets = expert_offsets.to(torch.int64)
else:
expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
......@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
dtype=torch.int32,
device=device)
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
if expert_map is not None:
a_map = torch.zeros((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
else:
a_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
num_expert = global_num_experts if expert_map is None \
else expert_map.size(0)
# permuted a1q reuses workspace2
a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
a1q,
a1q_scale,
topk_ids,
num_expert,
local_E,
expert_map,
permuted_hidden_states=a1q_perm)
expert_offsets = expert_offsets[:-1]
ab_strides1 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.size(0), ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.size(0), ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
else:
c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))
ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1,
problem_sizes2,
global_num_experts, N, K)
if not per_act_token and (expert_map is not None or use_batched_format):
# this is necessary to avoid imprecise scale calculation caused by
# random data in the unused workspace. The workspace is unused when
# this rank handles only partial tokens, or when it is batched .
c1.fill_(0)
mm1_out.fill_(0)
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch)
activation_callable(c2, c1)
activation_callable(act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
act_out,
a2_scale,
use_per_token_if_dynamic=per_act_token,
output=quant_out)
if expert_map is not None:
c3.fill_(0)
mm2_out.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets,
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch)
if use_batched_format:
output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True)
output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True)
else:
# We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
# for non-chunking mode the output is resized from workspace13
# so we need to make sure mm2_out uses workspace2.
moe_unpermute(out=output,
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm)
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
......@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
......@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
block_shape=block_shape,
))
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
......@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format)
use_batched_format, topk_weights)
class CutlassExpertsFp8(CutlassExpertsFp8Base):
......@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
)
......@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# topk weights and reduction are fused in moe_unpermute cuda kernel
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
a: torch.Tensor,
......@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, N // 2)
output = (M * topk, K)
workspace2 = (M * topk, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
......@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
)
assert max_experts_per_worker > 0
......@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
max(N // 2, K))
output = (self.max_experts_per_worker, padded_M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
......@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
......@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
......@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment