Commit 31330101 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-dev

parents e8933c34 dc1b4a6f
......@@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
......@@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._model_config = model_config
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._items_by_modality = defaultdict[str, list[_T]](list)
......@@ -465,6 +463,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path
@property
def mm_registry(self):
return MULTIMODAL_REGISTRY
@staticmethod
@cache
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
......@@ -487,8 +489,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)"
if model_type in ("blip-2", "fuyu", "paligemma", "pixtral",
"mistral3"):
if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
"pixtral", "mistral3"):
# These models do not use image tokens in the prompt
return None
if model_type == "qwen":
......@@ -498,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config.image_token_index)
if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
"internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat"):
"h2ovl_chat", "idefics3", "smolvlm"):
return "<image>"
if model_type in ("mllama", "llama4"):
return "<|image|>"
......@@ -506,8 +508,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "molmo":
return ""
if model_type == "idefics3":
return "<image>"
if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>"
if model_type == "gemma3":
......@@ -542,12 +542,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
mm_registry = self.mm_registry
model_config = self.model_config
input_modality = modality.replace("_embeds", "")
if mm_registry.has_processor(model_config):
mm_processor = mm_registry.create_processor(model_config)
allowed_counts = mm_processor.info.get_allowed_mm_limits()
allowed_count = allowed_counts.get(input_modality, 0)
else:
mm_config = model_config.multimodal_config
if mm_config is None:
msg = "This model does not support multi-modal inputs"
raise ValueError(msg)
allowed_count = mm_config.get_limit_per_prompt(input_modality)
current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")
"one request. You can set `--limit-mm-per-prompt` to "
"increase this limit if the model supports it.")
self._items_by_modality[modality].append(item)
......@@ -874,19 +891,19 @@ MM_PARSER_MAP: dict[
Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
"text":
lambda part: _TextParser(part).get("text", ""),
lambda part: _TextParser(part).get("text", None),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds":
lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
"input_audio":
lambda part: _InputAudioParser(part).get("input_audio", {}),
lambda part: _InputAudioParser(part).get("input_audio", None),
"refusal":
lambda part: _RefusalParser(part).get("refusal", ""),
lambda part: _RefusalParser(part).get("refusal", None),
"video_url":
lambda part: _VideoParser(part).get("video_url", {}).get("url", ""),
lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
}
......@@ -1005,11 +1022,11 @@ def _parse_chat_message_content_part(
part_type, content = _parse_chat_message_content_mm_part(part)
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
# content is None, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
logger.warning(
"Skipping multimodal part (type: '%s') "
"with empty / unparsable content.", part_type)
"Skipping multimodal part '%s' (type: '%s') "
"with empty / unparsable content.", part, part_type)
return None
if part_type in ("text", "refusal"):
......@@ -1195,8 +1212,15 @@ def apply_mistral_chat_template(
**kwargs,
)
return tokenizer.apply_chat_template(
messages=messages,
tools=tools,
**kwargs,
)
try:
return tokenizer.apply_chat_template(
messages=messages,
tools=tools,
**kwargs,
)
# mistral-common uses assert statements to stop processing of input
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# are properly caught in the preprocessing_input step
except AssertionError as e:
raise ValueError from e
......@@ -32,6 +32,7 @@ class BenchmarkSubcommandBase(CLISubcommand):
parser = subparsers.add_parser(
self.name,
help=self.help,
description=self.help,
usage=f"vllm bench {self.name} [options]")
self.add_cli_args(parser)
return parser
......@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand):
bench_parser = subparsers.add_parser(
"bench",
help="vLLM bench subcommand.",
description="vLLM bench subcommand.",
usage="vllm bench <bench_type> [options]")
bench_subparsers = bench_parser.add_subparsers(required=True,
dest="bench_type")
......
......@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand):
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
chat_parser = subparsers.add_parser(
"chat",
help="Generate chat completions via the running API server",
help="Generate chat completions via the running API server.",
description="Generate chat completions via the running API server.",
usage="vllm chat [options]")
_add_query_options(chat_parser)
chat_parser.add_argument(
......@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand):
complete_parser = subparsers.add_parser(
"complete",
help=("Generate text completions based on the given prompt "
"via the running API server"),
"via the running API server."),
description=("Generate text completions based on the given prompt "
"via the running API server."),
usage="vllm complete [options]")
_add_query_options(complete_parser)
return complete_parser
......
......@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand):
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
serve_parser = subparsers.add_parser(
"serve",
help="Start the vLLM OpenAI Compatible API server",
help="Start the vLLM OpenAI Compatible API server.",
description="Start the vLLM OpenAI Compatible API server.",
usage="vllm serve [model_tag] [options]")
serve_parser.add_argument("model_tag",
type=str,
......
......@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
import cloudpickle
import torch.nn as nn
from tqdm import tqdm
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
......@@ -117,6 +117,9 @@ class LLM:
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
......@@ -177,6 +180,7 @@ class LLM:
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
......@@ -232,6 +236,7 @@ class LLM:
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
hf_token=hf_token,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
......@@ -531,6 +536,16 @@ class LLM:
tokenizer.eos_token_id,
length_penalty)
# TODO - fix handling of multimodal data for beam search; we pass it
# through in the async version on the abstract EngineClient, but not
# here.
if any("multi_modal_data" in prompt
and prompt["multi_modal_data"] is not None
for prompt in prompts):
logger.warning(
"Multimodal data appears to have been provided, but is not"
" currently being passed through in LLM.beam_search()!")
tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
......@@ -906,6 +921,11 @@ class LLM:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
elif isinstance(pooling_params, PoolingParams):
pooling_params.verify(self.llm_engine.model_config)
else:
for pooling_param in pooling_params:
pooling_param.verify(self.llm_engine.model_config)
self._validate_and_add_requests(
prompts=parsed_prompts,
......@@ -924,6 +944,8 @@ class LLM:
/,
*,
use_tqdm: bool = True,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[EmbeddingRequestOutput]:
......@@ -938,6 +960,8 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
......@@ -953,6 +977,7 @@ class LLM:
items = self.encode(prompts,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
......
......@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema = self.response_format.json_schema
assert json_schema is not None
self.guided_json = json_schema.json_schema
if self.guided_decoding_backend is None:
self.guided_decoding_backend = "xgrammar"
guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
......@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# doc: end-embedding-extra-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
class EmbeddingChatRequest(OpenAIBaseModel):
......@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return data
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
......
......@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
truncate_tool_call_ids)
truncate_tool_call_ids,
validate_request_params)
logger = init_logger(__name__)
......@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request)
truncate_tool_call_ids(request)
validate_request_params(request)
if (request.tool_choice == "auto" and
not (self.enable_auto_tools and tool_parser is not None)
......
......@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
return error_check_ret
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = self._get_model_name(request.model)
request_id = f"embd-{self._base_request_id(raw_request)}"
......@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len."
" Please, select a smaller truncation size.")
pooling_params = request.to_pooling_params()
try:
pooling_params.verify(self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
try:
(
lora_request,
......@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
......
......@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception):
class PythonicToolParser(ToolParser):
"""
Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 models.
such as Llama 3.2 and Llama 4 models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
"""
......
......@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]:
# partial_json_parser doesn't support extra data and
# JSONDecorder.raw_decode doesn't support partial JSON
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
......
......@@ -23,6 +23,7 @@ if TYPE_CHECKING:
VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_SPEC_DECODE_EAGER: bool = False
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
......@@ -115,6 +116,7 @@ if TYPE_CHECKING:
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
def get_default_cache_root():
......@@ -289,6 +291,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION":
lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)),
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_ENFORCE_EAGER_BS_THRESHOLD":
lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")),
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
......@@ -716,6 +722,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
# Use model_redirect to redirect the model name to a local folder.
# `model_redirect` can be a json file mapping the model between
# repo_id and local folder:
# {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"}
# or a space separated values table file:
# meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B
"VLLM_MODEL_REDIRECT_PATH":
lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None),
......@@ -743,6 +754,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Control the cache sized used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB":
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
}
# end-env-vars-definition
......
......@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
and len(packed_modules_list) == 3)
#TODO: Implement this
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
pass
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
......
......@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager):
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
self.adapter_type = 'LoRa'
self.adapter_type = 'LoRA'
@property
def capacity(self) -> int:
......
......@@ -111,7 +111,7 @@ class LoRAKernelMeta:
# active_lora_ids, num_tokens_per_lora
lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
sorted=False,
sorted=True,
return_counts=True)
self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids,
non_blocking=True)
......
......@@ -33,6 +33,12 @@ def maybe_backend_fallback(
logger.warning("%s Falling back to use %s instead.", message, fallback)
guided_params.backend = fallback
# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if guided_params.backend == "auto":
guided_params.backend = "xgrammar"
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if guided_params.backend_name == "lm-format-enforcer":
if guided_params.grammar is not None:
......@@ -53,14 +59,9 @@ def maybe_backend_fallback(
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed)
# xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None:
fallback_or_error(
guided_params,
"xgrammar does not support regex guided decoding.", "outlines")
# xgrammar doesn't support some JSON schema features
elif (guided_params.json is not None
and has_xgrammar_unsupported_json_features(guided_params.json)):
if (guided_params.json is not None and
has_xgrammar_unsupported_json_features(guided_params.json)):
fallback_or_error(
guided_params,
"xgrammar does not support advanced JSON schema features like "
......
......@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if "pattern" in obj:
return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj for key in [
......
......@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List
import torch
import vllm.envs
from vllm.logger import init_logger
try:
......@@ -131,8 +132,13 @@ class GrammarCompilerCache:
encoded_vocab=config_data.encoded_vocab,
metadata=config_data.metadata,
)
cache_size = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
cls._cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info, max_threads=config.max_threads)
tokenizer_info,
max_threads=config.max_threads,
cache_enabled=True,
cache_limit_bytes=cache_size,
)
return cls._cache[cache_key]
......@@ -146,6 +152,7 @@ class GrammarConfig:
grammar_str: str | None = None
json_object: bool | None = None
any_whitespace: bool = True
regex_str: str | None = None
max_threads: int = 8
@classmethod
......@@ -249,6 +256,13 @@ class GrammarConfig:
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
elif guided_params.regex:
return cls(
regex_str=guided_params.regex,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
else:
raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar"
......@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor:
self.ctx = compiler\
.compile_json_schema('{"type": "object"}',
any_whitespace=any_whitespace)
elif self.config.regex_str:
self.ctx = compiler.compile_regex(self.config.regex_str)
else:
raise ValueError(
"Invalid configuration for xgrammar logits processor")
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
......@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
......@@ -479,51 +482,53 @@ def fused_moe_kernel_gptq_awq(
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr):
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
......@@ -605,11 +610,22 @@ def fused_moe_kernel(
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8 or use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
offs_bsn * stride_bsn)
# channel-wise
elif per_channel_quant:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:,
None]
# tensor-wise
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
......@@ -645,7 +661,11 @@ def fused_moe_kernel(
accumulator += tl.dot(a, b) * a_scale[:,
None] * b_scale[None, :]
else:
accumulator = tl.dot(a, b, acc=accumulator)
if use_fp8_w8a8:
# acc used to enable fp8_fast_accum
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
......@@ -693,33 +713,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.shape[0]
num_tokens = M * top_k
......@@ -887,7 +887,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
#BLOCK_SIZE_K=BLOCK_SIZE_K,
per_channel_quant=per_channel_quant,
# BLOCK_SIZE_K=BLOCK_SIZE_K,
**config,
)
......@@ -1263,6 +1264,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1275,9 +1277,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, use_nn_moe)
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
def inplace_fused_experts_fake(
......@@ -1292,6 +1295,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1310,6 +1314,7 @@ direct_register_custom_op(
op_func=inplace_fused_experts,
mutates_args=["hidden_states"],
fake_impl=inplace_fused_experts_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
......@@ -1325,6 +1330,7 @@ def outplace_fused_experts(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1337,7 +1343,8 @@ def outplace_fused_experts(
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, per_channel_quant,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
......@@ -1354,6 +1361,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1372,6 +1380,7 @@ direct_register_custom_op(
op_func=outplace_fused_experts,
mutates_args=[],
fake_impl=outplace_fused_experts_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
......@@ -1405,6 +1414,7 @@ def fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1448,6 +1458,7 @@ def fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
......@@ -1460,6 +1471,59 @@ def fused_experts(hidden_states: torch.Tensor,
use_nn_moe=use_nn_moe)
def moe_kernel_prepare_input(
A: torch.Tensor,
B: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
# activation channel-wise int8 quantization
assert (per_channel_quant
), "int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
else:
# activation block-wise int8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
return A, A_scale
def fused_experts_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
......@@ -1472,6 +1536,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1583,15 +1648,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qcurr_hidden_states, a1q_scale = _fp8_quantize(
curr_hidden_states, a1_scale, block_shape)
else:
qcurr_hidden_states = curr_hidden_states
a1q_scale = a1_scale
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
......@@ -1620,6 +1676,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
"num_warps": 4
}
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
A=curr_hidden_states,
B=w1,
A_scale=a1_scale,
B_scale=w1_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
if use_int4_w4a16:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
......@@ -1632,7 +1700,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel(qcurr_hidden_states,
w1,
intermediate_cache1,
a1q_scale,
qa1_scale,
w1_scale,
w1_zp,
curr_topk_weights,
......@@ -1647,6 +1715,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
......@@ -1658,15 +1727,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qintermediate_cache2, a2q_scale = _fp8_quantize(
intermediate_cache2, a2_scale, block_shape)
else:
qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
A=intermediate_cache2,
B=w2,
A_scale=a2_scale,
B_scale=w2_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
......@@ -1698,7 +1770,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
qa2_scale,
w2_scale,
w2_zp,
curr_topk_weights,
......@@ -1713,6 +1785,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
......@@ -1739,6 +1812,7 @@ def fused_moe(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1772,6 +1846,8 @@ def fused_moe(
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
......@@ -1821,6 +1897,7 @@ def fused_moe(
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment