Commit 9c4ecf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-ori

parents bfc2d6f7 dc1b4a6f
...@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand): ...@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand):
bench_parser = subparsers.add_parser( bench_parser = subparsers.add_parser(
"bench", "bench",
help="vLLM bench subcommand.", help="vLLM bench subcommand.",
description="vLLM bench subcommand.",
usage="vllm bench <bench_type> [options]") usage="vllm bench <bench_type> [options]")
bench_subparsers = bench_parser.add_subparsers(required=True, bench_subparsers = bench_parser.add_subparsers(required=True,
dest="bench_type") dest="bench_type")
......
...@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand): ...@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand):
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
chat_parser = subparsers.add_parser( chat_parser = subparsers.add_parser(
"chat", "chat",
help="Generate chat completions via the running API server", help="Generate chat completions via the running API server.",
description="Generate chat completions via the running API server.",
usage="vllm chat [options]") usage="vllm chat [options]")
_add_query_options(chat_parser) _add_query_options(chat_parser)
chat_parser.add_argument( chat_parser.add_argument(
...@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand): ...@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand):
complete_parser = subparsers.add_parser( complete_parser = subparsers.add_parser(
"complete", "complete",
help=("Generate text completions based on the given prompt " help=("Generate text completions based on the given prompt "
"via the running API server"), "via the running API server."),
description=("Generate text completions based on the given prompt "
"via the running API server."),
usage="vllm complete [options]") usage="vllm complete [options]")
_add_query_options(complete_parser) _add_query_options(complete_parser)
return complete_parser return complete_parser
......
...@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand): ...@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand):
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
serve_parser = subparsers.add_parser( serve_parser = subparsers.add_parser(
"serve", "serve",
help="Start the vLLM OpenAI Compatible API server", help="Start the vLLM OpenAI Compatible API server.",
description="Start the vLLM OpenAI Compatible API server.",
usage="vllm serve [model_tag] [options]") usage="vllm serve [model_tag] [options]")
serve_parser.add_argument("model_tag", serve_parser.add_argument("model_tag",
type=str, type=str,
......
...@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload ...@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
import cloudpickle import cloudpickle
import torch.nn as nn import torch.nn as nn
from tqdm import tqdm from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
...@@ -117,6 +117,9 @@ class LLM: ...@@ -117,6 +117,9 @@ class LLM:
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig` disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing. disable_async_output_proc: Disable async output processing.
This may result in lower performance. This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote
files. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the HuggingFace config. If a callable, it is called to update the
HuggingFace config. HuggingFace config.
...@@ -177,6 +180,7 @@ class LLM: ...@@ -177,6 +180,7 @@ class LLM:
max_seq_len_to_capture: int = 8192, max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False, disable_async_output_proc: bool = False,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None, hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None,
# After positional args are removed, move this right below `model` # After positional args are removed, move this right below `model`
...@@ -232,6 +236,7 @@ class LLM: ...@@ -232,6 +236,7 @@ class LLM:
max_seq_len_to_capture=max_seq_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce, disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc, disable_async_output_proc=disable_async_output_proc,
hf_token=hf_token,
hf_overrides=hf_overrides, hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config, override_pooler_config=override_pooler_config,
...@@ -531,6 +536,16 @@ class LLM: ...@@ -531,6 +536,16 @@ class LLM:
tokenizer.eos_token_id, tokenizer.eos_token_id,
length_penalty) length_penalty)
# TODO - fix handling of multimodal data for beam search; we pass it
# through in the async version on the abstract EngineClient, but not
# here.
if any("multi_modal_data" in prompt
and prompt["multi_modal_data"] is not None
for prompt in prompts):
logger.warning(
"Multimodal data appears to have been provided, but is not"
" currently being passed through in LLM.beam_search()!")
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step # generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation # following the huggingface transformers implementation
...@@ -906,6 +921,11 @@ class LLM: ...@@ -906,6 +921,11 @@ class LLM:
if pooling_params is None: if pooling_params is None:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
elif isinstance(pooling_params, PoolingParams):
pooling_params.verify(self.llm_engine.model_config)
else:
for pooling_param in pooling_params:
pooling_param.verify(self.llm_engine.model_config)
self._validate_and_add_requests( self._validate_and_add_requests(
prompts=parsed_prompts, prompts=parsed_prompts,
...@@ -924,6 +944,8 @@ class LLM: ...@@ -924,6 +944,8 @@ class LLM:
/, /,
*, *,
use_tqdm: bool = True, use_tqdm: bool = True,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[EmbeddingRequestOutput]: ) -> list[EmbeddingRequestOutput]:
...@@ -938,6 +960,8 @@ class LLM: ...@@ -938,6 +960,8 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType` for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts. for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for prompt_adapter_request: Prompt Adapter request to use for
...@@ -953,6 +977,7 @@ class LLM: ...@@ -953,6 +977,7 @@ class LLM:
items = self.encode(prompts, items = self.encode(prompts,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
......
...@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema = self.response_format.json_schema json_schema = self.response_format.json_schema
assert json_schema is not None assert json_schema is not None
self.guided_json = json_schema.json_schema self.guided_json = json_schema.json_schema
if self.guided_decoding_backend is None:
self.guided_decoding_backend = "xgrammar"
guided_decoding = GuidedDecodingParams.from_optional( guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json, json=self._get_guided_json_from_tool() or self.guided_json,
...@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): ...@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# doc: end-embedding-extra-params # doc: end-embedding-extra-params
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
class EmbeddingChatRequest(OpenAIBaseModel): class EmbeddingChatRequest(OpenAIBaseModel):
...@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel): ...@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return data return data
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
......
...@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams ...@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
truncate_tool_call_ids) truncate_tool_call_ids,
validate_request_params)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) maybe_serialize_tool_calls(request)
truncate_tool_call_ids(request) truncate_tool_call_ids(request)
validate_request_params(request)
if (request.tool_choice == "auto" and if (request.tool_choice == "auto" and
not (self.enable_auto_tools and tool_parser is not None) not (self.enable_auto_tools and tool_parser is not None)
......
...@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
return error_check_ret return error_check_ret
encoding_format = request.encoding_format encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = self._get_model_name(request.model) model_name = self._get_model_name(request.model)
request_id = f"embd-{self._base_request_id(raw_request)}" request_id = f"embd-{self._base_request_id(raw_request)}"
...@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len." "greater than max_model_len."
" Please, select a smaller truncation size.") " Please, select a smaller truncation size.")
pooling_params = request.to_pooling_params()
try:
pooling_params.verify(self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
try: try:
( (
lora_request, lora_request,
...@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try: try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
......
...@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception): ...@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception):
class PythonicToolParser(ToolParser): class PythonicToolParser(ToolParser):
""" """
Tool call parser for models that produce tool calls in a pythonic style, Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 models. such as Llama 3.2 and Llama 4 models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
""" """
......
...@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]: ...@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]:
# partial_json_parser doesn't support extra data and # partial_json_parser doesn't support extra data and
# JSONDecorder.raw_decode doesn't support partial JSON # JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]: def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
try: try:
return (partial_json_parser.loads(input_str, flags), len(input_str)) return (partial_json_parser.loads(input_str, flags), len(input_str))
......
...@@ -106,6 +106,7 @@ if TYPE_CHECKING: ...@@ -106,6 +106,7 @@ if TYPE_CHECKING:
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
def get_default_cache_root(): def get_default_cache_root():
...@@ -665,6 +666,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -665,6 +666,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
# Use model_redirect to redirect the model name to a local folder. # Use model_redirect to redirect the model name to a local folder.
# `model_redirect` can be a json file mapping the model between
# repo_id and local folder:
# {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"}
# or a space separated values table file:
# meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B
"VLLM_MODEL_REDIRECT_PATH": "VLLM_MODEL_REDIRECT_PATH":
lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None), lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None),
...@@ -692,6 +698,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -692,6 +698,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Allow use of DeepGemm kernels for fused moe ops. # Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM": "VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Control the cache size used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB":
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): ...@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
and len(packed_modules_list) == 3) and len(packed_modules_list) == 3)
#TODO: Implement this
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
pass
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None: def __init__(self, base_layer: RowParallelLinear) -> None:
......
...@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager):
self._last_mapping: Optional[LoRAMapping] = None self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self self.model.lora_manager = self
self.adapter_type = 'LoRa' self.adapter_type = 'LoRA'
@property @property
def capacity(self) -> int: def capacity(self) -> int:
......
...@@ -111,7 +111,7 @@ class LoRAKernelMeta: ...@@ -111,7 +111,7 @@ class LoRAKernelMeta:
# active_lora_ids, num_tokens_per_lora # active_lora_ids, num_tokens_per_lora
lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping, lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
sorted=False, sorted=True,
return_counts=True) return_counts=True)
self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids,
non_blocking=True) non_blocking=True)
......
...@@ -33,6 +33,12 @@ def maybe_backend_fallback( ...@@ -33,6 +33,12 @@ def maybe_backend_fallback(
logger.warning("%s Falling back to use %s instead.", message, fallback) logger.warning("%s Falling back to use %s instead.", message, fallback)
guided_params.backend = fallback guided_params.backend = fallback
# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if guided_params.backend == "auto":
guided_params.backend = "xgrammar"
# lm-format-enforce doesn't support grammar, fallback to xgrammar # lm-format-enforce doesn't support grammar, fallback to xgrammar
if guided_params.backend_name == "lm-format-enforcer": if guided_params.backend_name == "lm-format-enforcer":
if guided_params.grammar is not None: if guided_params.grammar is not None:
...@@ -53,14 +59,9 @@ def maybe_backend_fallback( ...@@ -53,14 +59,9 @@ def maybe_backend_fallback(
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed) xgr_installed)
# xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None:
fallback_or_error(
guided_params,
"xgrammar does not support regex guided decoding.", "outlines")
# xgrammar doesn't support some JSON schema features # xgrammar doesn't support some JSON schema features
elif (guided_params.json is not None if (guided_params.json is not None and
and has_xgrammar_unsupported_json_features(guided_params.json)): has_xgrammar_unsupported_json_features(guided_params.json)):
fallback_or_error( fallback_or_error(
guided_params, guided_params,
"xgrammar does not support advanced JSON schema features like " "xgrammar does not support advanced JSON schema features like "
......
...@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool: ...@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if "pattern" in obj: if "pattern" in obj:
return True return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges # Check for numeric ranges
if obj.get("type") in ("integer", "number") and any( if obj.get("type") in ("integer", "number") and any(
key in obj for key in [ key in obj for key in [
......
...@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List ...@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List
import torch import torch
import vllm.envs
from vllm.logger import init_logger from vllm.logger import init_logger
try: try:
...@@ -131,8 +132,13 @@ class GrammarCompilerCache: ...@@ -131,8 +132,13 @@ class GrammarCompilerCache:
encoded_vocab=config_data.encoded_vocab, encoded_vocab=config_data.encoded_vocab,
metadata=config_data.metadata, metadata=config_data.metadata,
) )
cache_size = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
cls._cache[cache_key] = xgr.GrammarCompiler( cls._cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info, max_threads=config.max_threads) tokenizer_info,
max_threads=config.max_threads,
cache_enabled=True,
cache_limit_bytes=cache_size,
)
return cls._cache[cache_key] return cls._cache[cache_key]
...@@ -146,6 +152,7 @@ class GrammarConfig: ...@@ -146,6 +152,7 @@ class GrammarConfig:
grammar_str: str | None = None grammar_str: str | None = None
json_object: bool | None = None json_object: bool | None = None
any_whitespace: bool = True any_whitespace: bool = True
regex_str: str | None = None
max_threads: int = 8 max_threads: int = 8
@classmethod @classmethod
...@@ -249,6 +256,13 @@ class GrammarConfig: ...@@ -249,6 +256,13 @@ class GrammarConfig:
max_threads=max_threads, max_threads=max_threads,
tokenizer_data=tokenizer_data, tokenizer_data=tokenizer_data,
) )
elif guided_params.regex:
return cls(
regex_str=guided_params.regex,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
else: else:
raise ValueError( raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar" "Currently only support JSON and EBNF grammar mode for xgrammar"
...@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor: ...@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor:
self.ctx = compiler\ self.ctx = compiler\
.compile_json_schema('{"type": "object"}', .compile_json_schema('{"type": "object"}',
any_whitespace=any_whitespace) any_whitespace=any_whitespace)
elif self.config.regex_str:
self.ctx = compiler.compile_regex(self.config.regex_str)
else: else:
raise ValueError( raise ValueError(
"Invalid configuration for xgrammar logits processor") "Invalid configuration for xgrammar logits processor")
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
...@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( ...@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8) _valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size) moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -251,50 +254,53 @@ def fused_moe_kernel_gptq_awq( ...@@ -251,50 +254,53 @@ def fused_moe_kernel_gptq_awq(
@triton.jit @triton.jit
def fused_moe_kernel( def fused_moe_kernel(
# Pointers to matrices # Pointers to matrices
a_ptr, a_ptr,
b_ptr, b_ptr,
c_ptr, c_ptr,
a_scale_ptr, a_scale_ptr,
b_scale_ptr, b_scale_ptr,
topk_weights_ptr, topk_weights_ptr,
sorted_token_ids_ptr, sorted_token_ids_ptr,
expert_ids_ptr, expert_ids_ptr,
num_tokens_post_padded_ptr, num_tokens_post_padded_ptr,
# Matrix dimensions # Matrix dimensions
N, N,
K, K,
EM, EM,
num_valid_tokens, num_valid_tokens,
# The stride variables represent how much to increase the ptr by when # The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is # moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down # how much to increase `a_ptr` by to get the element one row down
# (A has M rows). # (A has M rows).
stride_am, stride_am,
stride_ak, stride_ak,
stride_be, stride_be,
stride_bk, stride_bk,
stride_bn, stride_bn,
stride_cm, stride_cm,
stride_cn, stride_cn,
stride_asm, stride_asm,
stride_ask, stride_ask,
stride_bse, stride_bse,
stride_bsk, stride_bsk,
stride_bsn, stride_bsn,
# Block size for block-wise quantization # Block size for block-wise quantization
group_n: tl.constexpr, group_n: tl.constexpr,
group_k: tl.constexpr, group_k: tl.constexpr,
# Meta-parameters # Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr, top_k: tl.constexpr,
compute_type: tl.constexpr, compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr, use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr): use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices. token and expert matrices.
...@@ -385,12 +391,23 @@ def fused_moe_kernel( ...@@ -385,12 +391,23 @@ def fused_moe_kernel(
None, :] * stride_bsn None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs) b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8: if use_fp8_w8a8 or use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n offs_bsn = offs_bn // group_n
b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
offs_bsn * stride_bsn) offs_bsn * stride_bsn)
# channel-wise
elif per_channel_quant:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:,
None]
# tensor-wise
else: else:
a_scale = tl.load(a_scale_ptr) a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts) b_scale = tl.load(b_scale_ptr + off_experts)
...@@ -414,7 +431,7 @@ def fused_moe_kernel( ...@@ -414,7 +431,7 @@ def fused_moe_kernel(
# We accumulate along the K dimension. # We accumulate along the K dimension.
if use_int8_w8a16: if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8: elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k offs_ks = k_start // group_k
...@@ -426,7 +443,11 @@ def fused_moe_kernel( ...@@ -426,7 +443,11 @@ def fused_moe_kernel(
accumulator += tl.dot(a, b) * a_scale[:, accumulator += tl.dot(a, b) * a_scale[:,
None] * b_scale[None, :] None] * b_scale[None, :]
else: else:
accumulator = tl.dot(a, b, acc=accumulator) if use_fp8_w8a8:
# acc used to enable fp8_fast_accum
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
else: else:
accumulator += tl.dot(a, b) accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block. # Advance the ptrs to the next K block.
...@@ -440,7 +461,7 @@ def fused_moe_kernel( ...@@ -440,7 +461,7 @@ def fused_moe_kernel(
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16: if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type) accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8: elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
else: else:
...@@ -471,28 +492,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -471,28 +492,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config: Dict[str, Any], config: Dict[str, Any],
compute_type: tl.dtype, compute_type: tl.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None: use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights is not None or not mul_routed_weight assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1 assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.shape[0] M = A.shape[0]
num_tokens = M * top_k num_tokens = M * top_k
...@@ -619,7 +628,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -619,7 +628,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_K=BLOCK_SIZE_K,
**config, **config,
) )
...@@ -981,8 +992,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -981,8 +992,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
activation: Optional[str] = None, activation: Optional[str] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -995,9 +1008,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -995,9 +1008,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, global_num_experts, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, per_channel_quant, global_num_experts, expert_map,
a2_scale, block_shape, use_nn_moe) w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1009,8 +1023,10 @@ def inplace_fused_experts_fake( ...@@ -1009,8 +1023,10 @@ def inplace_fused_experts_fake(
activation: Optional[str] = None, activation: Optional[str] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1029,6 +1045,7 @@ direct_register_custom_op( ...@@ -1029,6 +1045,7 @@ direct_register_custom_op(
op_func=inplace_fused_experts, op_func=inplace_fused_experts,
mutates_args=["hidden_states"], mutates_args=["hidden_states"],
fake_impl=inplace_fused_experts_fake, fake_impl=inplace_fused_experts_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
) )
...@@ -1041,8 +1058,10 @@ def outplace_fused_experts( ...@@ -1041,8 +1058,10 @@ def outplace_fused_experts(
activation: Optional[str] = None, activation: Optional[str] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1055,7 +1074,8 @@ def outplace_fused_experts( ...@@ -1055,7 +1074,8 @@ def outplace_fused_experts(
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, per_channel_quant,
global_num_experts, expert_map, w1_scale, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe)
...@@ -1069,8 +1089,10 @@ def outplace_fused_experts_fake( ...@@ -1069,8 +1089,10 @@ def outplace_fused_experts_fake(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: Optional[str] = None, activation: Optional[str] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1089,6 +1111,7 @@ direct_register_custom_op( ...@@ -1089,6 +1111,7 @@ direct_register_custom_op(
op_func=outplace_fused_experts, op_func=outplace_fused_experts,
mutates_args=[], mutates_args=[],
fake_impl=outplace_fused_experts_fake, fake_impl=outplace_fused_experts_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
) )
...@@ -1119,8 +1142,10 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1119,8 +1142,10 @@ def fused_experts(hidden_states: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1160,8 +1185,10 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1160,8 +1185,10 @@ def fused_experts(hidden_states: torch.Tensor,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale, w1_scale=w1_scale,
...@@ -1174,6 +1201,59 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1174,6 +1201,59 @@ def fused_experts(hidden_states: torch.Tensor,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
def moe_kernel_prepare_input(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Quantize the activation `A` as required by the selected MoE scheme.

    Exactly one branch is taken, checked in this order: fp8 w8a8,
    int8 w8a8, weight-only (int8 w8a16 / int4 w4a16), unquantized.

    Args:
        A: Activation tensor fed to the fused MoE kernel.
        B: Weight tensor. Not read by the active code in this function;
            only the (disabled) shape checks on `B_scale` referred to it.
        A_scale: Optional activation scale. Forwarded to
            `ops.scaled_fp8_quant` on the non-block fp8 path; must be
            None on the unquantized path.
        B_scale: Weight scale. Must be set for every quantized scheme and
            must be None on the unquantized path.
        use_fp8_w8a8: Select fp8 weight / fp8 activation quantization.
        use_int8_w8a8: Select int8 weight / int8 activation quantization.
        use_int8_w8a16: Select int8 weight-only quantization.
        use_int4_w4a16: Select int4 weight-only quantization.
        per_channel_quant: With fp8 and no `block_shape`, requests
            per-token dynamic activation quantization. With int8 w8a8 and
            no `block_shape` it is mandatory.
        block_shape: Optional two-element block shape; only the second
            entry (`block_k`) is used to group-quantize `A`.

    Returns:
        `(A, A_scale)`: the (possibly quantized) activation and its scale.
        On the weight-only and unquantized paths both are returned
        unchanged.

    Raises:
        AssertionError: If the scales or `block_shape` do not match the
            requirements of the selected scheme.
    """
    if use_fp8_w8a8:
        assert B_scale is not None
        if block_shape is None:
            # If weights are per-channel (per_channel_quant=True), then
            # activations apply per-token quantization. Otherwise, assume
            # activation tensor-wise fp8 quantization, dynamic or static.
            A, A_scale = ops.scaled_fp8_quant(
                A, A_scale, use_per_token_if_dynamic=per_channel_quant)
        else:
            # Activation block-wise fp8 quantization.
            assert len(block_shape) == 2
            block_k = block_shape[1]
            A, A_scale = per_token_group_quant_fp8(A, block_k)
            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
            # NOTE: the matching block-count checks on B_scale are left
            # disabled here.
    elif use_int8_w8a8:
        assert B_scale is not None
        if block_shape is None:
            # Activation channel-wise int8 quantization.
            assert (per_channel_quant
                    ), "int8 quantization only supports block or channel-wise"
            A, A_scale = per_token_quant_int8(A)
        else:
            # Activation block-wise int8 quantization.
            assert len(block_shape) == 2
            block_k = block_shape[1]
            A, A_scale = per_token_group_quant_int8(A, block_k)
            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
            # NOTE: the matching block-count checks on B_scale are left
            # disabled here.
    elif use_int8_w8a16 or use_int4_w4a16:
        # Weight-only schemes: the activation is passed through untouched.
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        # Unquantized path: no scales may be supplied.
        assert A_scale is None
        assert B_scale is None

    return A, A_scale
def fused_experts_impl(hidden_states: torch.Tensor, def fused_experts_impl(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -1183,8 +1263,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1183,8 +1263,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1294,14 +1376,17 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1294,14 +1376,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
A=curr_hidden_states,
if use_fp8_w8a8: B=w1,
qcurr_hidden_states, a1q_scale = _fp8_quantize( A_scale=a1_scale,
curr_hidden_states, a1_scale, block_shape) B_scale=w1_scale,
else: use_fp8_w8a8=use_fp8_w8a8,
qcurr_hidden_states = curr_hidden_states use_int8_w8a8=use_int8_w8a8,
a1q_scale = a1_scale use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
sorted_token_ids, expert_ids, num_tokens_post_padded = ( sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
...@@ -1310,7 +1395,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1310,7 +1395,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel(qcurr_hidden_states, invoke_fused_moe_kernel(qcurr_hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
a1q_scale, qa1_scale,
w1_scale, w1_scale,
w1_zp, w1_zp,
curr_topk_weights, curr_topk_weights,
...@@ -1322,8 +1407,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1322,8 +1407,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
...@@ -1336,19 +1423,22 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1336,19 +1423,22 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else: else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}") raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
A=intermediate_cache2,
if use_fp8_w8a8: B=w2,
qintermediate_cache2, a2q_scale = _fp8_quantize( A_scale=a2_scale,
intermediate_cache2, a2_scale, block_shape) B_scale=w2_scale,
else: use_fp8_w8a8=use_fp8_w8a8,
qintermediate_cache2 = intermediate_cache2 use_int8_w8a8=use_int8_w8a8,
a2q_scale = a2_scale use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
invoke_fused_moe_kernel(qintermediate_cache2, invoke_fused_moe_kernel(qintermediate_cache2,
w2, w2,
intermediate_cache3, intermediate_cache3,
a2q_scale, qa2_scale,
w2_scale, w2_scale,
w2_zp, w2_zp,
curr_topk_weights, curr_topk_weights,
...@@ -1360,8 +1450,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1360,8 +1450,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
...@@ -1385,8 +1477,10 @@ def fused_moe( ...@@ -1385,8 +1477,10 @@ def fused_moe(
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1420,6 +1514,8 @@ def fused_moe( ...@@ -1420,6 +1514,8 @@ def fused_moe(
note: Deepseekv2 model uses grouped_topk note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False. products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2. activation to compute the inner products for w1 and w2.
Defaults to False. Defaults to False.
...@@ -1466,8 +1562,10 @@ def fused_moe( ...@@ -1466,8 +1562,10 @@ def fused_moe(
inplace=inplace, inplace=inplace,
activation=activation, activation=activation,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale, w1_scale=w1_scale,
......
...@@ -192,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -192,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
def forward_cuda( def forward_cuda(
...@@ -458,7 +458,7 @@ class FusedMoE(torch.nn.Module): ...@@ -458,7 +458,7 @@ class FusedMoE(torch.nn.Module):
# Use expert parallelism instead of tensor parallelism? # Use expert parallelism instead of tensor parallelism?
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
use_ep = (vllm_config.parallel_config.enable_expert_parallel use_ep = (vllm_config.parallel_config.enable_expert_parallel
and self.tp_size > 1) and self.tp_size * self.dp_size > 1)
# For smuggling this layer into the fused moe custom op # For smuggling this layer into the fused moe custom op
self.use_direct_call = self.dp_size == 1 self.use_direct_call = self.dp_size == 1
...@@ -542,7 +542,9 @@ class FusedMoE(torch.nn.Module): ...@@ -542,7 +542,9 @@ class FusedMoE(torch.nn.Module):
} }
# need full intermediate size pre-sharding for WNA16 act order # need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): in ("GPTQMarlinMoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
...@@ -691,9 +693,10 @@ class FusedMoE(torch.nn.Module): ...@@ -691,9 +693,10 @@ class FusedMoE(torch.nn.Module):
# compressed-tensors checkpoints with packed weights are stored flipped # compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format # TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality # against known CompressionFormat enum values that have this quality
loaded_weight = loaded_weight.t().contiguous() if ( if self.quant_method.__class__.__name__ in (
self.quant_method.__class__.__name__ "CompressedTensorsWNA16MarlinMoEMethod",
== "CompressedTensorsWNA16MoEMethod") else loaded_weight "CompressedTensorsWNA16MoEMethod"):
loaded_weight = loaded_weight.t().contiguous()
if shard_id not in ("w1", "w2", "w3"): if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but " raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
......
...@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase): ...@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
prefix=f"{prefix}.kv_proj_encoder") prefix=f"{prefix}.kv_proj_encoder")
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self.q_size = self.q_proj_decoder.output_size_per_partition
self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
if bias: if bias:
...@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase): ...@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
else: else:
self.bias = None self.bias = None
def process_weights_after_loading(self):
    """Apply the quant method's post-load processing to each sub-layer.

    Calls `self.quant_method.process_weights_after_loading(layer)` for
    every layer stored in `self.proj`. Does nothing when no quant method
    is set.
    """
    quant_method = self.quant_method
    # The check does not depend on the layer, so do it once up front.
    if quant_method is None:
        return
    for layer in self.proj.values():
        quant_method.process_weights_after_loading(layer)
@property @property
def q_proj_decoder(self) -> ColumnParallelLinear: def q_proj_decoder(self) -> ColumnParallelLinear:
layer = self.proj["q_proj_decoder"] layer = self.proj["q_proj_decoder"]
for name, param in self.named_parameters(): for name, param in self.named_parameters():
target_param = getattr(layer, name) target_param = getattr(layer, name, None)
self.sync_weight_attrs(param, target_param, mode="q_proj_decoder") if target_param is not None:
self.sync_weight_attrs(param,
target_param,
mode="q_proj_decoder")
return layer return layer
@property @property
def kv_proj_encoder(self) -> QKVParallelLinear: def kv_proj_encoder(self) -> QKVParallelLinear:
layer = self.proj["kv_proj_encoder"] layer = self.proj["kv_proj_encoder"]
for name, param in self.named_parameters(): for name, param in self.named_parameters():
target_param = getattr(layer, name) target_param = getattr(layer, name, None)
self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder") if target_param is not None:
self.sync_weight_attrs(param,
target_param,
mode="kv_proj_encoder")
return layer return layer
def sync_weight_attrs( def sync_weight_attrs(
...@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase): ...@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
if loaded_shard_id == "q" else self.kv_proj_encoder) if loaded_shard_id == "q" else self.kv_proj_encoder)
target_param = self.select_proj_params(layer, param) target_param = self.select_proj_params(layer, param)
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
layer.weight_loader(target_param, loaded_weight, *shard_id_args) if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
else:
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"in_features={self.input_size}" s = f"in_features={self.input_size}"
s += f", q_size={self.q_proj_decoder.output_size_per_partition}" s += f", q_size={self.q_size}"
s += f", kv_size={self.kv_size}" s += f", kv_size={self.kv_size}"
s += f", bias={self.bias is not None}" s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}" s += f", tp_size={get_tensor_model_parallel_world_size()}"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment