Commit 31330101 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-dev

parents e8933c34 dc1b4a6f
...@@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict ...@@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.utils import MediaConnector from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
...@@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._model_config = model_config self._model_config = model_config
self._tokenizer = tokenizer self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._items_by_modality = defaultdict[str, list[_T]](list) self._items_by_modality = defaultdict[str, list[_T]](list)
...@@ -465,6 +463,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -465,6 +463,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def allowed_local_media_path(self): def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path return self._model_config.allowed_local_media_path
@property
def mm_registry(self):
return MULTIMODAL_REGISTRY
@staticmethod @staticmethod
@cache @cache
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
...@@ -487,8 +489,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -487,8 +489,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|endoftext10|>" # 200010 (see vocab.json in hf model) return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"): if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)" return "(<image>./</image>)"
if model_type in ("blip-2", "fuyu", "paligemma", "pixtral", if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
"mistral3"): "pixtral", "mistral3"):
# These models do not use image tokens in the prompt # These models do not use image tokens in the prompt
return None return None
if model_type == "qwen": if model_type == "qwen":
...@@ -498,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -498,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config.image_token_index) hf_config.image_token_index)
if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2", if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
"internvl_chat", "skywork_chat", "NVLM_D", "internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat"): "h2ovl_chat", "idefics3", "smolvlm"):
return "<image>" return "<image>"
if model_type in ("mllama", "llama4"): if model_type in ("mllama", "llama4"):
return "<|image|>" return "<|image|>"
...@@ -506,8 +508,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -506,8 +508,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|vision_start|><|image_pad|><|vision_end|>" return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "molmo": if model_type == "molmo":
return "" return ""
if model_type == "idefics3":
return "<image>"
if model_type == "aria": if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>" return "<|fim_prefix|><|img|><|fim_suffix|>"
if model_type == "gemma3": if model_type == "gemma3":
...@@ -542,12 +542,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -542,12 +542,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
Add a multi-modal item to the current prompt and returns the Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any. placeholder string to use, if any.
""" """
allowed_count = self._allowed_items.get(modality, 1) mm_registry = self.mm_registry
model_config = self.model_config
input_modality = modality.replace("_embeds", "")
if mm_registry.has_processor(model_config):
mm_processor = mm_registry.create_processor(model_config)
allowed_counts = mm_processor.info.get_allowed_mm_limits()
allowed_count = allowed_counts.get(input_modality, 0)
else:
mm_config = model_config.multimodal_config
if mm_config is None:
msg = "This model does not support multi-modal inputs"
raise ValueError(msg)
allowed_count = mm_config.get_limit_per_prompt(input_modality)
current_count = len(self._items_by_modality[modality]) + 1 current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count: if current_count > allowed_count:
raise ValueError( raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in " f"At most {allowed_count} {modality}(s) may be provided in "
"one request.") "one request. You can set `--limit-mm-per-prompt` to "
"increase this limit if the model supports it.")
self._items_by_modality[modality].append(item) self._items_by_modality[modality].append(item)
...@@ -874,19 +891,19 @@ MM_PARSER_MAP: dict[ ...@@ -874,19 +891,19 @@ MM_PARSER_MAP: dict[
Callable[[ChatCompletionContentPartParam], _ContentPart], Callable[[ChatCompletionContentPartParam], _ContentPart],
] = { ] = {
"text": "text":
lambda part: _TextParser(part).get("text", ""), lambda part: _TextParser(part).get("text", None),
"image_url": "image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""), lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds": "image_embeds":
lambda part: _ImageEmbedsParser(part).get("image_embeds", {}), lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
"audio_url": "audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""), lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
"input_audio": "input_audio":
lambda part: _InputAudioParser(part).get("input_audio", {}), lambda part: _InputAudioParser(part).get("input_audio", None),
"refusal": "refusal":
lambda part: _RefusalParser(part).get("refusal", ""), lambda part: _RefusalParser(part).get("refusal", None),
"video_url": "video_url":
lambda part: _VideoParser(part).get("video_url", {}).get("url", ""), lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
} }
...@@ -1005,11 +1022,11 @@ def _parse_chat_message_content_part( ...@@ -1005,11 +1022,11 @@ def _parse_chat_message_content_part(
part_type, content = _parse_chat_message_content_mm_part(part) part_type, content = _parse_chat_message_content_mm_part(part)
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# content is empty, log a warning and skip # content is None, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content: if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
logger.warning( logger.warning(
"Skipping multimodal part (type: '%s') " "Skipping multimodal part '%s' (type: '%s') "
"with empty / unparsable content.", part_type) "with empty / unparsable content.", part, part_type)
return None return None
if part_type in ("text", "refusal"): if part_type in ("text", "refusal"):
...@@ -1195,8 +1212,15 @@ def apply_mistral_chat_template( ...@@ -1195,8 +1212,15 @@ def apply_mistral_chat_template(
**kwargs, **kwargs,
) )
return tokenizer.apply_chat_template( try:
messages=messages, return tokenizer.apply_chat_template(
tools=tools, messages=messages,
**kwargs, tools=tools,
) **kwargs,
)
# mistral-common uses assert statements to stop processing of input
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# properly caught in the preprocessing_input step
except AssertionError as e:
raise ValueError from e
...@@ -32,6 +32,7 @@ class BenchmarkSubcommandBase(CLISubcommand): ...@@ -32,6 +32,7 @@ class BenchmarkSubcommandBase(CLISubcommand):
parser = subparsers.add_parser( parser = subparsers.add_parser(
self.name, self.name,
help=self.help, help=self.help,
description=self.help,
usage=f"vllm bench {self.name} [options]") usage=f"vllm bench {self.name} [options]")
self.add_cli_args(parser) self.add_cli_args(parser)
return parser return parser
...@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand): ...@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand):
bench_parser = subparsers.add_parser( bench_parser = subparsers.add_parser(
"bench", "bench",
help="vLLM bench subcommand.", help="vLLM bench subcommand.",
description="vLLM bench subcommand.",
usage="vllm bench <bench_type> [options]") usage="vllm bench <bench_type> [options]")
bench_subparsers = bench_parser.add_subparsers(required=True, bench_subparsers = bench_parser.add_subparsers(required=True,
dest="bench_type") dest="bench_type")
......
...@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand): ...@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand):
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
chat_parser = subparsers.add_parser( chat_parser = subparsers.add_parser(
"chat", "chat",
help="Generate chat completions via the running API server", help="Generate chat completions via the running API server.",
description="Generate chat completions via the running API server.",
usage="vllm chat [options]") usage="vllm chat [options]")
_add_query_options(chat_parser) _add_query_options(chat_parser)
chat_parser.add_argument( chat_parser.add_argument(
...@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand): ...@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand):
complete_parser = subparsers.add_parser( complete_parser = subparsers.add_parser(
"complete", "complete",
help=("Generate text completions based on the given prompt " help=("Generate text completions based on the given prompt "
"via the running API server"), "via the running API server."),
description=("Generate text completions based on the given prompt "
"via the running API server."),
usage="vllm complete [options]") usage="vllm complete [options]")
_add_query_options(complete_parser) _add_query_options(complete_parser)
return complete_parser return complete_parser
......
...@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand): ...@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand):
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
serve_parser = subparsers.add_parser( serve_parser = subparsers.add_parser(
"serve", "serve",
help="Start the vLLM OpenAI Compatible API server", help="Start the vLLM OpenAI Compatible API server.",
description="Start the vLLM OpenAI Compatible API server.",
usage="vllm serve [model_tag] [options]") usage="vllm serve [model_tag] [options]")
serve_parser.add_argument("model_tag", serve_parser.add_argument("model_tag",
type=str, type=str,
......
...@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload ...@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
import cloudpickle import cloudpickle
import torch.nn as nn import torch.nn as nn
from tqdm import tqdm from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
...@@ -117,6 +117,9 @@ class LLM: ...@@ -117,6 +117,9 @@ class LLM:
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig` disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing. disable_async_output_proc: Disable async output processing.
This may result in lower performance. This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote
files. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the HuggingFace config. If a callable, it is called to update the
HuggingFace config. HuggingFace config.
...@@ -177,6 +180,7 @@ class LLM: ...@@ -177,6 +180,7 @@ class LLM:
max_seq_len_to_capture: int = 8192, max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False, disable_async_output_proc: bool = False,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None, hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None,
# After positional args are removed, move this right below `model` # After positional args are removed, move this right below `model`
...@@ -232,6 +236,7 @@ class LLM: ...@@ -232,6 +236,7 @@ class LLM:
max_seq_len_to_capture=max_seq_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce, disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc, disable_async_output_proc=disable_async_output_proc,
hf_token=hf_token,
hf_overrides=hf_overrides, hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config, override_pooler_config=override_pooler_config,
...@@ -531,6 +536,16 @@ class LLM: ...@@ -531,6 +536,16 @@ class LLM:
tokenizer.eos_token_id, tokenizer.eos_token_id,
length_penalty) length_penalty)
# TODO - fix handling of multimodal data for beam search; we pass it
# through in the async version on the abstract EngineClient, but not
# here.
if any("multi_modal_data" in prompt
and prompt["multi_modal_data"] is not None
for prompt in prompts):
logger.warning(
"Multimodal data appears to have been provided, but is not"
" currently being passed through in LLM.beam_search()!")
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step # generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation # following the huggingface transformers implementation
...@@ -906,6 +921,11 @@ class LLM: ...@@ -906,6 +921,11 @@ class LLM:
if pooling_params is None: if pooling_params is None:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
elif isinstance(pooling_params, PoolingParams):
pooling_params.verify(self.llm_engine.model_config)
else:
for pooling_param in pooling_params:
pooling_param.verify(self.llm_engine.model_config)
self._validate_and_add_requests( self._validate_and_add_requests(
prompts=parsed_prompts, prompts=parsed_prompts,
...@@ -924,6 +944,8 @@ class LLM: ...@@ -924,6 +944,8 @@ class LLM:
/, /,
*, *,
use_tqdm: bool = True, use_tqdm: bool = True,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[EmbeddingRequestOutput]: ) -> list[EmbeddingRequestOutput]:
...@@ -938,6 +960,8 @@ class LLM: ...@@ -938,6 +960,8 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType` for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts. for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for prompt_adapter_request: Prompt Adapter request to use for
...@@ -953,6 +977,7 @@ class LLM: ...@@ -953,6 +977,7 @@ class LLM:
items = self.encode(prompts, items = self.encode(prompts,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
......
...@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema = self.response_format.json_schema json_schema = self.response_format.json_schema
assert json_schema is not None assert json_schema is not None
self.guided_json = json_schema.json_schema self.guided_json = json_schema.json_schema
if self.guided_decoding_backend is None:
self.guided_decoding_backend = "xgrammar"
guided_decoding = GuidedDecodingParams.from_optional( guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json, json=self._get_guided_json_from_tool() or self.guided_json,
...@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): ...@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# doc: end-embedding-extra-params # doc: end-embedding-extra-params
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
class EmbeddingChatRequest(OpenAIBaseModel): class EmbeddingChatRequest(OpenAIBaseModel):
...@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel): ...@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return data return data
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
......
...@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams ...@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
truncate_tool_call_ids) truncate_tool_call_ids,
validate_request_params)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) maybe_serialize_tool_calls(request)
truncate_tool_call_ids(request) truncate_tool_call_ids(request)
validate_request_params(request)
if (request.tool_choice == "auto" and if (request.tool_choice == "auto" and
not (self.enable_auto_tools and tool_parser is not None) not (self.enable_auto_tools and tool_parser is not None)
......
...@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
return error_check_ret return error_check_ret
encoding_format = request.encoding_format encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = self._get_model_name(request.model) model_name = self._get_model_name(request.model)
request_id = f"embd-{self._base_request_id(raw_request)}" request_id = f"embd-{self._base_request_id(raw_request)}"
...@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len." "greater than max_model_len."
" Please, select a smaller truncation size.") " Please, select a smaller truncation size.")
pooling_params = request.to_pooling_params()
try:
pooling_params.verify(self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
try: try:
( (
lora_request, lora_request,
...@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try: try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
......
...@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception): ...@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception):
class PythonicToolParser(ToolParser): class PythonicToolParser(ToolParser):
""" """
Tool call parser for models that produce tool calls in a pythonic style, Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 models. such as Llama 3.2 and Llama 4 models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
""" """
......
...@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]: ...@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]:
# partial_json_parser doesn't support extra data and # partial_json_parser doesn't support extra data and
# JSONDecorder.raw_decode doesn't support partial JSON # JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]: def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
try: try:
return (partial_json_parser.loads(input_str, flags), len(input_str)) return (partial_json_parser.loads(input_str, flags), len(input_str))
......
...@@ -23,6 +23,7 @@ if TYPE_CHECKING: ...@@ -23,6 +23,7 @@ if TYPE_CHECKING:
VLLM_USE_TC_PAGED_ATTN: bool = False VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_SPEC_DECODE_EAGER: bool = False VLLM_SPEC_DECODE_EAGER: bool = False
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_FLASH_ATTN_VERSION: Optional[int] = None VLLM_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None CUDA_VISIBLE_DEVICES: Optional[str] = None
...@@ -115,6 +116,7 @@ if TYPE_CHECKING: ...@@ -115,6 +116,7 @@ if TYPE_CHECKING:
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
def get_default_cache_root(): def get_default_cache_root():
...@@ -289,6 +291,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -289,6 +291,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# when using the flash-attention backend. # when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION": "VLLM_FLASH_ATTN_VERSION":
lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)), lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)),
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_ENFORCE_EAGER_BS_THRESHOLD":
lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")),
# Internal flag to enable Dynamo fullgraph capture # Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
...@@ -716,6 +722,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -716,6 +722,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
# Use model_redirect to redirect the model name to a local folder. # Use model_redirect to redirect the model name to a local folder.
# `model_redirect` can be a json file mapping the model between
# repo_id and local folder:
# {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"}
# or a space separated values table file:
# meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B
"VLLM_MODEL_REDIRECT_PATH": "VLLM_MODEL_REDIRECT_PATH":
lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None), lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None),
...@@ -743,6 +754,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -743,6 +754,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Allow use of DeepGemm kernels for fused moe ops. # Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM": "VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Control the cache size used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB":
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): ...@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
and len(packed_modules_list) == 3) and len(packed_modules_list) == 3)
#TODO: Implement this
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
pass
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None: def __init__(self, base_layer: RowParallelLinear) -> None:
......
...@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager):
self._last_mapping: Optional[LoRAMapping] = None self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self self.model.lora_manager = self
self.adapter_type = 'LoRa' self.adapter_type = 'LoRA'
@property @property
def capacity(self) -> int: def capacity(self) -> int:
......
...@@ -111,7 +111,7 @@ class LoRAKernelMeta: ...@@ -111,7 +111,7 @@ class LoRAKernelMeta:
# active_lora_ids, num_tokens_per_lora # active_lora_ids, num_tokens_per_lora
lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping, lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
sorted=False, sorted=True,
return_counts=True) return_counts=True)
self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids,
non_blocking=True) non_blocking=True)
......
...@@ -33,6 +33,12 @@ def maybe_backend_fallback( ...@@ -33,6 +33,12 @@ def maybe_backend_fallback(
logger.warning("%s Falling back to use %s instead.", message, fallback) logger.warning("%s Falling back to use %s instead.", message, fallback)
guided_params.backend = fallback guided_params.backend = fallback
# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if guided_params.backend == "auto":
guided_params.backend = "xgrammar"
# lm-format-enforce doesn't support grammar, fallback to xgrammar # lm-format-enforce doesn't support grammar, fallback to xgrammar
if guided_params.backend_name == "lm-format-enforcer": if guided_params.backend_name == "lm-format-enforcer":
if guided_params.grammar is not None: if guided_params.grammar is not None:
...@@ -53,14 +59,9 @@ def maybe_backend_fallback( ...@@ -53,14 +59,9 @@ def maybe_backend_fallback(
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed) xgr_installed)
# xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None:
fallback_or_error(
guided_params,
"xgrammar does not support regex guided decoding.", "outlines")
# xgrammar doesn't support some JSON schema features # xgrammar doesn't support some JSON schema features
elif (guided_params.json is not None if (guided_params.json is not None and
and has_xgrammar_unsupported_json_features(guided_params.json)): has_xgrammar_unsupported_json_features(guided_params.json)):
fallback_or_error( fallback_or_error(
guided_params, guided_params,
"xgrammar does not support advanced JSON schema features like " "xgrammar does not support advanced JSON schema features like "
......
...@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool: ...@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if "pattern" in obj: if "pattern" in obj:
return True return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges # Check for numeric ranges
if obj.get("type") in ("integer", "number") and any( if obj.get("type") in ("integer", "number") and any(
key in obj for key in [ key in obj for key in [
......
...@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List ...@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List
import torch import torch
import vllm.envs
from vllm.logger import init_logger from vllm.logger import init_logger
try: try:
...@@ -131,8 +132,13 @@ class GrammarCompilerCache: ...@@ -131,8 +132,13 @@ class GrammarCompilerCache:
encoded_vocab=config_data.encoded_vocab, encoded_vocab=config_data.encoded_vocab,
metadata=config_data.metadata, metadata=config_data.metadata,
) )
cache_size = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
cls._cache[cache_key] = xgr.GrammarCompiler( cls._cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info, max_threads=config.max_threads) tokenizer_info,
max_threads=config.max_threads,
cache_enabled=True,
cache_limit_bytes=cache_size,
)
return cls._cache[cache_key] return cls._cache[cache_key]
...@@ -146,6 +152,7 @@ class GrammarConfig: ...@@ -146,6 +152,7 @@ class GrammarConfig:
grammar_str: str | None = None grammar_str: str | None = None
json_object: bool | None = None json_object: bool | None = None
any_whitespace: bool = True any_whitespace: bool = True
regex_str: str | None = None
max_threads: int = 8 max_threads: int = 8
@classmethod @classmethod
...@@ -249,6 +256,13 @@ class GrammarConfig: ...@@ -249,6 +256,13 @@ class GrammarConfig:
max_threads=max_threads, max_threads=max_threads,
tokenizer_data=tokenizer_data, tokenizer_data=tokenizer_data,
) )
elif guided_params.regex:
return cls(
regex_str=guided_params.regex,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
else: else:
raise ValueError( raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar" "Currently only support JSON and EBNF grammar mode for xgrammar"
...@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor: ...@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor:
self.ctx = compiler\ self.ctx = compiler\
.compile_json_schema('{"type": "object"}', .compile_json_schema('{"type": "object"}',
any_whitespace=any_whitespace) any_whitespace=any_whitespace)
elif self.config.regex_str:
self.ctx = compiler.compile_regex(self.config.regex_str)
else: else:
raise ValueError( raise ValueError(
"Invalid configuration for xgrammar logits processor") "Invalid configuration for xgrammar logits processor")
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
...@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( ...@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8) _valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size) moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -479,51 +482,53 @@ def fused_moe_kernel_gptq_awq( ...@@ -479,51 +482,53 @@ def fused_moe_kernel_gptq_awq(
@triton.jit @triton.jit
def fused_moe_kernel( def fused_moe_kernel(
# Pointers to matrices # Pointers to matrices
a_ptr, a_ptr,
b_ptr, b_ptr,
c_ptr, c_ptr,
a_scale_ptr, a_scale_ptr,
b_scale_ptr, b_scale_ptr,
topk_weights_ptr, topk_weights_ptr,
sorted_token_ids_ptr, sorted_token_ids_ptr,
expert_ids_ptr, expert_ids_ptr,
num_tokens_post_padded_ptr, num_tokens_post_padded_ptr,
# Matrix dimensions # Matrix dimensions
N, N,
K, K,
EM, EM,
num_valid_tokens, num_valid_tokens,
# The stride variables represent how much to increase the ptr by when # The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is # moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down # how much to increase `a_ptr` by to get the element one row down
# (A has M rows). # (A has M rows).
stride_am, stride_am,
stride_ak, stride_ak,
stride_be, stride_be,
stride_bk, stride_bk,
stride_bn, stride_bn,
stride_cm, stride_cm,
stride_cn, stride_cn,
stride_asm, stride_asm,
stride_ask, stride_ask,
stride_bse, stride_bse,
stride_bsk, stride_bsk,
stride_bsn, stride_bsn,
# Block size for block-wise quantization # Block size for block-wise quantization
group_n: tl.constexpr, group_n: tl.constexpr,
group_k: tl.constexpr, group_k: tl.constexpr,
# Meta-parameters # Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr, top_k: tl.constexpr,
compute_type: tl.constexpr, compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr, use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr, use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr): use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices. token and expert matrices.
...@@ -605,11 +610,22 @@ def fused_moe_kernel( ...@@ -605,11 +610,22 @@ def fused_moe_kernel(
b_scale = tl.load(b_scale_ptrs) b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8 or use_int8_w8a8: if use_fp8_w8a8 or use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n offs_bsn = offs_bn // group_n
b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
offs_bsn * stride_bsn) offs_bsn * stride_bsn)
# channel-wise
elif per_channel_quant:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:,
None]
# tensor-wise
else: else:
a_scale = tl.load(a_scale_ptr) a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts) b_scale = tl.load(b_scale_ptr + off_experts)
...@@ -645,7 +661,11 @@ def fused_moe_kernel( ...@@ -645,7 +661,11 @@ def fused_moe_kernel(
accumulator += tl.dot(a, b) * a_scale[:, accumulator += tl.dot(a, b) * a_scale[:,
None] * b_scale[None, :] None] * b_scale[None, :]
else: else:
accumulator = tl.dot(a, b, acc=accumulator) if use_fp8_w8a8:
# acc used to enable fp8_fast_accum
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
else: else:
accumulator += tl.dot(a, b) accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block. # Advance the ptrs to the next K block.
...@@ -693,33 +713,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -693,33 +713,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8: bool, use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None: use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights is not None or not mul_routed_weight assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1 assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.shape[0] M = A.shape[0]
num_tokens = M * top_k num_tokens = M * top_k
...@@ -887,7 +887,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -887,7 +887,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
#BLOCK_SIZE_K=BLOCK_SIZE_K, per_channel_quant=per_channel_quant,
# BLOCK_SIZE_K=BLOCK_SIZE_K,
**config, **config,
) )
...@@ -1263,6 +1264,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1263,6 +1264,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1275,9 +1277,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1275,9 +1277,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, per_channel_quant, global_num_experts, expert_map,
a2_scale, block_shape, use_nn_moe) w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1292,6 +1295,7 @@ def inplace_fused_experts_fake( ...@@ -1292,6 +1295,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1310,6 +1314,7 @@ direct_register_custom_op( ...@@ -1310,6 +1314,7 @@ direct_register_custom_op(
op_func=inplace_fused_experts, op_func=inplace_fused_experts,
mutates_args=["hidden_states"], mutates_args=["hidden_states"],
fake_impl=inplace_fused_experts_fake, fake_impl=inplace_fused_experts_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
) )
...@@ -1325,6 +1330,7 @@ def outplace_fused_experts( ...@@ -1325,6 +1330,7 @@ def outplace_fused_experts(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1337,7 +1343,8 @@ def outplace_fused_experts( ...@@ -1337,7 +1343,8 @@ def outplace_fused_experts(
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, per_channel_quant,
global_num_experts, expert_map, w1_scale, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe)
...@@ -1354,6 +1361,7 @@ def outplace_fused_experts_fake( ...@@ -1354,6 +1361,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1372,6 +1380,7 @@ direct_register_custom_op( ...@@ -1372,6 +1380,7 @@ direct_register_custom_op(
op_func=outplace_fused_experts, op_func=outplace_fused_experts,
mutates_args=[], mutates_args=[],
fake_impl=outplace_fused_experts_fake, fake_impl=outplace_fused_experts_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
) )
...@@ -1405,6 +1414,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1405,6 +1414,7 @@ def fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1448,6 +1458,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1448,6 +1458,7 @@ def fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale, w1_scale=w1_scale,
...@@ -1460,6 +1471,59 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1460,6 +1471,59 @@ def fused_experts(hidden_states: torch.Tensor,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
def moe_kernel_prepare_input(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Validate the scales and quantize the activations for one MoE GEMM.

    The quantization mode is chosen from the flags in priority order:
    ``use_fp8_w8a8``, then ``use_int8_w8a8``, then the weight-only modes
    (``use_int8_w8a16`` / ``use_int4_w4a16``), and finally no quantization.

    Args:
        A: Activation tensor to (possibly) quantize.
        B: Weight tensor. Accepted for interface symmetry; this function
            does not read it.
        A_scale: Optional pre-computed activation scale. Only forwarded to
            the tensor-wise / per-token fp8 path; must be ``None`` when no
            quantization mode is selected.
        B_scale: Weight scale. Required for every quantized mode and must
            be ``None`` when no quantization mode is selected.
        use_fp8_w8a8: Quantize activations to fp8.
        use_int8_w8a8: Quantize activations to int8.
        use_int8_w8a16: Weight-only int8; activations pass through.
        use_int4_w4a16: Weight-only int4; activations pass through.
        per_channel_quant: For fp8 without ``block_shape``, request
            per-token activation scales instead of tensor-wise ones. For
            int8 without ``block_shape`` this must be true.
        block_shape: Optional two-element block size; its second entry is
            used as the group size along the last dimension of ``A``.

    Returns:
        ``(A, A_scale)`` - the quantized activations with their scales for
        the w8a8 modes, or the inputs unchanged otherwise.

    Raises:
        AssertionError: If the scales or ``block_shape`` are inconsistent
            with the selected mode.
    """
    activation_is_quantized = use_fp8_w8a8 or use_int8_w8a8

    # Weight-only and unquantized modes leave the activations untouched;
    # only the arguments are sanity-checked.
    if not activation_is_quantized:
        if use_int8_w8a16 or use_int4_w4a16:
            assert B_scale is not None
            assert block_shape is None or block_shape[0] == 0
        else:
            assert A_scale is None
            assert B_scale is None
        return A, A_scale

    assert B_scale is not None

    # Block-wise activation quantization: one scale per group of `block_k`
    # elements along the last dimension. fp8 wins if both flags are set.
    if block_shape is not None:
        assert len(block_shape) == 2
        block_k = block_shape[1]
        group_quant = (per_token_group_quant_fp8
                       if use_fp8_w8a8 else per_token_group_quant_int8)
        quantized, scales = group_quant(A, block_k)
        assert triton.cdiv(quantized.shape[-1], block_k) == scales.shape[-1]
        return quantized, scales

    if use_fp8_w8a8:
        # Per-channel weights pair with per-token activation scales;
        # otherwise fall back to tensor-wise fp8 quantization, dynamic or
        # static depending on whether A_scale was supplied.
        quantized, scales = ops.scaled_fp8_quant(
            A, A_scale, use_per_token_if_dynamic=per_channel_quant)
        return quantized, scales

    # int8 without a block shape: only channel-wise is supported, which
    # quantizes the activations per token.
    assert (per_channel_quant
            ), "int8 quantization only supports block or channel-wise"
    quantized, scales = per_token_quant_int8(A)
    return quantized, scales
def fused_experts_impl(hidden_states: torch.Tensor, def fused_experts_impl(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -1472,6 +1536,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1472,6 +1536,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1583,15 +1648,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1583,15 +1648,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qcurr_hidden_states, a1q_scale = _fp8_quantize(
curr_hidden_states, a1_scale, block_shape)
else:
qcurr_hidden_states = curr_hidden_states
a1q_scale = a1_scale
if use_int8_w8a8: if use_int8_w8a8:
m=curr_hidden_states.shape[0] m=curr_hidden_states.shape[0]
...@@ -1620,6 +1676,18 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1620,6 +1676,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
"num_warps": 4 "num_warps": 4
} }
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
A=curr_hidden_states,
B=w1,
A_scale=a1_scale,
B_scale=w1_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
if use_int4_w4a16: if use_int4_w4a16:
sorted_token_ids, expert_ids, num_tokens_post_padded = ( sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
...@@ -1632,7 +1700,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1632,7 +1700,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel(qcurr_hidden_states, invoke_fused_moe_kernel(qcurr_hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
a1q_scale, qa1_scale,
w1_scale, w1_scale,
w1_zp, w1_zp,
curr_topk_weights, curr_topk_weights,
...@@ -1647,6 +1715,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1647,6 +1715,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
...@@ -1658,15 +1727,18 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1658,15 +1727,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
else: else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}") raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
A=intermediate_cache2,
if use_fp8_w8a8: B=w2,
qintermediate_cache2, a2q_scale = _fp8_quantize( A_scale=a2_scale,
intermediate_cache2, a2_scale, block_shape) B_scale=w2_scale,
else: use_fp8_w8a8=use_fp8_w8a8,
qintermediate_cache2 = intermediate_cache2 use_int8_w8a8=use_int8_w8a8,
a2q_scale = a2_scale use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
if use_int8_w8a8: if use_int8_w8a8:
m=curr_hidden_states.shape[0] m=curr_hidden_states.shape[0]
...@@ -1698,7 +1770,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1698,7 +1770,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel(qintermediate_cache2, invoke_fused_moe_kernel(qintermediate_cache2,
w2, w2,
intermediate_cache3, intermediate_cache3,
a2q_scale, qa2_scale,
w2_scale, w2_scale,
w2_zp, w2_zp,
curr_topk_weights, curr_topk_weights,
...@@ -1713,6 +1785,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1713,6 +1785,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
...@@ -1739,6 +1812,7 @@ def fused_moe( ...@@ -1739,6 +1812,7 @@ def fused_moe(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1772,6 +1846,8 @@ def fused_moe( ...@@ -1772,6 +1846,8 @@ def fused_moe(
note: Deepseekv2 model uses grouped_topk note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False. products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2. activation to compute the inner products for w1 and w2.
Defaults to False. Defaults to False.
...@@ -1821,6 +1897,7 @@ def fused_moe( ...@@ -1821,6 +1897,7 @@ def fused_moe(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale, w1_scale=w1_scale,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment