Commit 9c4ecf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-ori

parents bfc2d6f7 dc1b4a6f
......@@ -33,6 +33,7 @@ class BenchmarkSubcommand(CLISubcommand):
bench_parser = subparsers.add_parser(
"bench",
help="vLLM bench subcommand.",
description="vLLM bench subcommand.",
usage="vllm bench <bench_type> [options]")
bench_subparsers = bench_parser.add_subparsers(required=True,
dest="bench_type")
......
......@@ -126,7 +126,8 @@ class ChatCommand(CLISubcommand):
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
chat_parser = subparsers.add_parser(
"chat",
help="Generate chat completions via the running API server",
help="Generate chat completions via the running API server.",
description="Generate chat completions via the running API server.",
usage="vllm chat [options]")
_add_query_options(chat_parser)
chat_parser.add_argument(
......@@ -162,7 +163,9 @@ class CompleteCommand(CLISubcommand):
complete_parser = subparsers.add_parser(
"complete",
help=("Generate text completions based on the given prompt "
"via the running API server"),
"via the running API server."),
description=("Generate text completions based on the given prompt "
"via the running API server."),
usage="vllm complete [options]")
_add_query_options(complete_parser)
return complete_parser
......
......@@ -34,7 +34,8 @@ class ServeSubcommand(CLISubcommand):
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
serve_parser = subparsers.add_parser(
"serve",
help="Start the vLLM OpenAI Compatible API server",
help="Start the vLLM OpenAI Compatible API server.",
description="Start the vLLM OpenAI Compatible API server.",
usage="vllm serve [model_tag] [options]")
serve_parser.add_argument("model_tag",
type=str,
......
......@@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
import cloudpickle
import torch.nn as nn
from tqdm import tqdm
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
......@@ -117,6 +117,9 @@ class LLM:
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
......@@ -177,6 +180,7 @@ class LLM:
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
......@@ -232,6 +236,7 @@ class LLM:
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
hf_token=hf_token,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
......@@ -531,6 +536,16 @@ class LLM:
tokenizer.eos_token_id,
length_penalty)
# TODO - fix handling of multimodal data for beam search; we pass it
# through in the async version on the abstract EngineClient, but not
# here.
if any("multi_modal_data" in prompt
and prompt["multi_modal_data"] is not None
for prompt in prompts):
logger.warning(
"Multimodal data appears to have been provided, but is not"
" currently being passed through in LLM.beam_search()!")
tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
......@@ -906,6 +921,11 @@ class LLM:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
elif isinstance(pooling_params, PoolingParams):
pooling_params.verify(self.llm_engine.model_config)
else:
for pooling_param in pooling_params:
pooling_param.verify(self.llm_engine.model_config)
self._validate_and_add_requests(
prompts=parsed_prompts,
......@@ -924,6 +944,8 @@ class LLM:
/,
*,
use_tqdm: bool = True,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[EmbeddingRequestOutput]:
......@@ -938,6 +960,8 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
......@@ -953,6 +977,7 @@ class LLM:
items = self.encode(prompts,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
......
......@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema = self.response_format.json_schema
assert json_schema is not None
self.guided_json = json_schema.json_schema
if self.guided_decoding_backend is None:
self.guided_decoding_backend = "xgrammar"
guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
......@@ -1008,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# doc: end-embedding-extra-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
class EmbeddingChatRequest(OpenAIBaseModel):
......@@ -1070,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return data
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
......
......@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
truncate_tool_call_ids)
truncate_tool_call_ids,
validate_request_params)
logger = init_logger(__name__)
......@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request)
truncate_tool_call_ids(request)
validate_request_params(request)
if (request.tool_choice == "auto" and
not (self.enable_auto_tools and tool_parser is not None)
......
......@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
return error_check_ret
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = self._get_model_name(request.model)
request_id = f"embd-{self._base_request_id(raw_request)}"
......@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len."
" Please, select a smaller truncation size.")
pooling_params = request.to_pooling_params()
try:
pooling_params.verify(self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
try:
(
lora_request,
......@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
......
......@@ -28,7 +28,7 @@ class _UnexpectedAstError(Exception):
class PythonicToolParser(ToolParser):
"""
Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 models.
such as Llama 3.2 and Llama 4 models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
"""
......
......@@ -98,7 +98,7 @@ def find_all_indices(string: str, substring: str) -> list[int]:
# partial_json_parser doesn't support extra data and
# JSONDecorder.raw_decode doesn't support partial JSON
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
......
......@@ -106,6 +106,7 @@ if TYPE_CHECKING:
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
def get_default_cache_root():
......@@ -665,6 +666,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
# Use model_redirect to redirect the model name to a local folder.
# `model_redirect` can be a json file mapping the model between
# repo_id and local folder:
# {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"}
# or a space separated values table file:
# meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B
"VLLM_MODEL_REDIRECT_PATH":
lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None),
......@@ -692,6 +698,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Control the cache sized used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB":
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
}
# end-env-vars-definition
......
......@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
and len(packed_modules_list) == 3)
#TODO: Implement this
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
pass
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
......
......@@ -364,7 +364,7 @@ class LoRAModelManager(AdapterModelManager):
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
self.adapter_type = 'LoRa'
self.adapter_type = 'LoRA'
@property
def capacity(self) -> int:
......
......@@ -111,7 +111,7 @@ class LoRAKernelMeta:
# active_lora_ids, num_tokens_per_lora
lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
sorted=False,
sorted=True,
return_counts=True)
self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids,
non_blocking=True)
......
......@@ -33,6 +33,12 @@ def maybe_backend_fallback(
logger.warning("%s Falling back to use %s instead.", message, fallback)
guided_params.backend = fallback
# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if guided_params.backend == "auto":
guided_params.backend = "xgrammar"
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if guided_params.backend_name == "lm-format-enforcer":
if guided_params.grammar is not None:
......@@ -53,14 +59,9 @@ def maybe_backend_fallback(
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed)
# xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None:
fallback_or_error(
guided_params,
"xgrammar does not support regex guided decoding.", "outlines")
# xgrammar doesn't support some JSON schema features
elif (guided_params.json is not None
and has_xgrammar_unsupported_json_features(guided_params.json)):
if (guided_params.json is not None and
has_xgrammar_unsupported_json_features(guided_params.json)):
fallback_or_error(
guided_params,
"xgrammar does not support advanced JSON schema features like "
......
......@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if "pattern" in obj:
return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj for key in [
......
......@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List
import torch
import vllm.envs
from vllm.logger import init_logger
try:
......@@ -131,8 +132,13 @@ class GrammarCompilerCache:
encoded_vocab=config_data.encoded_vocab,
metadata=config_data.metadata,
)
cache_size = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
cls._cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info, max_threads=config.max_threads)
tokenizer_info,
max_threads=config.max_threads,
cache_enabled=True,
cache_limit_bytes=cache_size,
)
return cls._cache[cache_key]
......@@ -146,6 +152,7 @@ class GrammarConfig:
grammar_str: str | None = None
json_object: bool | None = None
any_whitespace: bool = True
regex_str: str | None = None
max_threads: int = 8
@classmethod
......@@ -249,6 +256,13 @@ class GrammarConfig:
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
elif guided_params.regex:
return cls(
regex_str=guided_params.regex,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
else:
raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar"
......@@ -324,6 +338,8 @@ class XGrammarLogitsProcessor:
self.ctx = compiler\
.compile_json_schema('{"type": "object"}',
any_whitespace=any_whitespace)
elif self.config.regex_str:
self.ctx = compiler.compile_regex(self.config.regex_str)
else:
raise ValueError(
"Invalid configuration for xgrammar logits processor")
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
......@@ -16,7 +16,10 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
......@@ -251,50 +254,53 @@ def fused_moe_kernel_gptq_awq(
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr):
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
......@@ -385,12 +391,23 @@ def fused_moe_kernel(
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8:
if use_fp8_w8a8 or use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
offs_bsn * stride_bsn)
# channel-wise
elif per_channel_quant:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:,
None]
# tensor-wise
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
......@@ -414,7 +431,7 @@ def fused_moe_kernel(
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8:
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
......@@ -426,7 +443,11 @@ def fused_moe_kernel(
accumulator += tl.dot(a, b) * a_scale[:,
None] * b_scale[None, :]
else:
accumulator = tl.dot(a, b, acc=accumulator)
if use_fp8_w8a8:
# acc used to enable fp8_fast_accum
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
......@@ -440,7 +461,7 @@ def fused_moe_kernel(
accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8:
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
......@@ -471,28 +492,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config: Dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.shape[0]
num_tokens = M * top_k
......@@ -619,7 +628,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
top_k=top_k,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
BLOCK_SIZE_K=BLOCK_SIZE_K,
**config,
)
......@@ -981,8 +992,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
activation: Optional[str] = None,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -995,9 +1008,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, use_nn_moe)
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
def inplace_fused_experts_fake(
......@@ -1009,8 +1023,10 @@ def inplace_fused_experts_fake(
activation: Optional[str] = None,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1029,6 +1045,7 @@ direct_register_custom_op(
op_func=inplace_fused_experts,
mutates_args=["hidden_states"],
fake_impl=inplace_fused_experts_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
......@@ -1041,8 +1058,10 @@ def outplace_fused_experts(
activation: Optional[str] = None,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1055,7 +1074,8 @@ def outplace_fused_experts(
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, per_channel_quant,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
......@@ -1069,8 +1089,10 @@ def outplace_fused_experts_fake(
topk_ids: torch.Tensor,
activation: Optional[str] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1089,6 +1111,7 @@ direct_register_custom_op(
op_func=outplace_fused_experts,
mutates_args=[],
fake_impl=outplace_fused_experts_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
......@@ -1119,8 +1142,10 @@ def fused_experts(hidden_states: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1160,8 +1185,10 @@ def fused_experts(hidden_states: torch.Tensor,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
......@@ -1174,6 +1201,59 @@ def fused_experts(hidden_states: torch.Tensor,
use_nn_moe=use_nn_moe)
def moe_kernel_prepare_input(
A: torch.Tensor,
B: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
# activation channel-wise int8 quantization
assert (per_channel_quant
), "int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
else:
# activation block-wise int8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
return A, A_scale
def fused_experts_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
......@@ -1183,8 +1263,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1294,14 +1376,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qcurr_hidden_states, a1q_scale = _fp8_quantize(
curr_hidden_states, a1_scale, block_shape)
else:
qcurr_hidden_states = curr_hidden_states
a1q_scale = a1_scale
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
A=curr_hidden_states,
B=w1,
A_scale=a1_scale,
B_scale=w1_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
......@@ -1310,7 +1395,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel(qcurr_hidden_states,
w1,
intermediate_cache1,
a1q_scale,
qa1_scale,
w1_scale,
w1_zp,
curr_topk_weights,
......@@ -1322,8 +1407,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
......@@ -1336,19 +1423,22 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qintermediate_cache2, a2q_scale = _fp8_quantize(
intermediate_cache2, a2_scale, block_shape)
else:
qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
A=intermediate_cache2,
B=w2,
A_scale=a2_scale,
B_scale=w2_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
qa2_scale,
w2_scale,
w2_zp,
curr_topk_weights,
......@@ -1360,8 +1450,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
......@@ -1385,8 +1477,10 @@ def fused_moe(
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
......@@ -1420,6 +1514,8 @@ def fused_moe(
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
......@@ -1466,8 +1562,10 @@ def fused_moe(
inplace=inplace,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
......
......@@ -192,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
apply_router_weight_on_input=apply_router_weight_on_input
use_nn_moe=use_nn_moe)
def forward_cuda(
......@@ -458,7 +458,7 @@ class FusedMoE(torch.nn.Module):
# Use expert parallelism instead of tensor parallelism?
vllm_config = get_current_vllm_config()
use_ep = (vllm_config.parallel_config.enable_expert_parallel
and self.tp_size > 1)
and self.tp_size * self.dp_size > 1)
# For smuggling this layer into the fused moe custom op
self.use_direct_call = self.dp_size == 1
......@@ -542,7 +542,9 @@ class FusedMoE(torch.nn.Module):
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
in ("GPTQMarlinMoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
......@@ -691,9 +693,10 @@ class FusedMoE(torch.nn.Module):
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight = loaded_weight.t().contiguous() if (
self.quant_method.__class__.__name__
== "CompressedTensorsWNA16MoEMethod") else loaded_weight
if self.quant_method.__class__.__name__ in (
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod"):
loaded_weight = loaded_weight.t().contiguous()
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
......
......@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
prefix=f"{prefix}.kv_proj_encoder")
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self.q_size = self.q_proj_decoder.output_size_per_partition
self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
if bias:
......@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
else:
self.bias = None
def process_weights_after_loading(self):
for layer in self.proj.values():
if self.quant_method is not None:
self.quant_method.process_weights_after_loading(layer)
@property
def q_proj_decoder(self) -> ColumnParallelLinear:
layer = self.proj["q_proj_decoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
target_param = getattr(layer, name, None)
if target_param is not None:
self.sync_weight_attrs(param,
target_param,
mode="q_proj_decoder")
return layer
@property
def kv_proj_encoder(self) -> QKVParallelLinear:
layer = self.proj["kv_proj_encoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
target_param = getattr(layer, name, None)
if target_param is not None:
self.sync_weight_attrs(param,
target_param,
mode="kv_proj_encoder")
return layer
def sync_weight_attrs(
......@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
if loaded_shard_id == "q" else self.kv_proj_encoder)
target_param = self.select_proj_params(layer, param)
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
else:
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
s += f", q_size={self.q_size}"
s += f", kv_size={self.kv_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment