Commit 899a2db4 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1(ex fused_moe&models)

parent 78c1f9e5
......@@ -63,7 +63,6 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
make_tool_call_id,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import ToolServer
......@@ -116,7 +115,6 @@ from vllm.entrypoints.openai.responses.utils import (
extract_tool_types,
should_continue_final_message,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
......@@ -252,17 +250,6 @@ class OpenAIServingResponses(OpenAIServing):
self.default_sampling_params["stop_token_ids"].extend(
get_stop_tokens_for_assistant_actions()
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.enable_auto_tools = enable_auto_tools
# set up tool use
self.tool_parser = self._get_tool_parser(
......@@ -436,11 +423,8 @@ class OpenAIServingResponses(OpenAIServing):
if maybe_error is not None:
return maybe_error
default_max_tokens = get_max_tokens(
self.max_model_len,
request,
engine_prompt,
self.default_sampling_params,
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
sampling_params = request.to_sampling_params(
......@@ -970,28 +954,25 @@ class OpenAIServingResponses(OpenAIServing):
enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser,
)
if content or (self.use_harmony and tool_calls):
res_text_part = None
if content:
res_text_part = ResponseOutputText(
text=content,
annotations=[], # TODO
type="output_text",
logprobs=(
self._create_response_logprobs(
token_ids=final_output.token_ids,
logprobs=final_output.logprobs,
tokenizer=tokenizer,
top_logprobs=request.top_logprobs,
)
if request.is_include_output_logprobs()
else None
),
)
if content:
output_text = ResponseOutputText(
text=content,
annotations=[], # TODO
type="output_text",
logprobs=(
self._create_response_logprobs(
token_ids=final_output.token_ids,
logprobs=final_output.logprobs,
tokenizer=tokenizer,
top_logprobs=request.top_logprobs,
)
if request.is_include_output_logprobs()
else None
),
)
message_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[res_text_part] if res_text_part else [],
content=[output_text],
role="assistant",
status="completed",
type="message",
......@@ -1003,28 +984,17 @@ class OpenAIServingResponses(OpenAIServing):
if message_item:
outputs.append(message_item)
if tool_calls:
# We use a simple counter for history_tool_call_count because
# we don't track the history of tool calls in the Responses API yet.
# This means that the tool call index will start from 0 for each
# request.
tool_call_items = []
for history_tool_call_cnt, tool_call in enumerate(tool_calls):
tool_call_items.append(
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=tool_call.id
if tool_call.id
else make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
),
type="function_call",
status="completed",
name=tool_call.name,
arguments=tool_call.arguments,
)
tool_call_items = [
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=f"call_{random_uuid()}",
type="function_call",
status="completed",
name=tool_call.name,
arguments=tool_call.arguments,
)
for tool_call in tool_calls
]
outputs.extend(tool_call_items)
return outputs
......@@ -2589,4 +2559,4 @@ class OpenAIServingResponses(OpenAIServing):
sequence_number=-1,
response=final_response,
)
)
)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from random import choices
from string import ascii_letters, digits
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from pydantic import Field
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
class MistralToolCall(ToolCall):
id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id())
@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))
@staticmethod
def is_valid_id(id: str) -> bool:
return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
return (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
)
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
- the examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral model...")
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[
str
] = [] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if _is_fn_name_regex_support(self.model_tokenizer):
self.fn_name_regex = re.compile(
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
)
else:
self.fn_name_regex = None
if self.bot_token_id is None:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!"
)
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if (
not isinstance(self.model_tokenizer, MistralTokenizer)
and request.tools
and request.tool_choice != "none"
):
# Do not skip special tokens when using chat template
# with Mistral parser as TOOL_CALL token is needed
# for tool detection.
# Note: we don't want skip_special_tokens=False
# with MistralTokenizer as it is incompatible
request.skip_special_tokens = False
return request
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing,
make sure your tool call arguments don't ever include quotes!
"""
# case -- if a tool call token is not present, return a text response
if self.bot_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# first remove the BOT token
tool_content = model_output.replace(self.bot_token, "").strip()
try:
# we first try to directly load the json as parsing very nested
# jsons is difficult
try:
if self.fn_name_regex:
matches = self.fn_name_regex.findall(tool_content)
function_call_arr = []
for match in matches:
fn_name = match[0]
args = match[1]
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)}
)
else:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
function_call_arr = json.loads(raw_tool_call)
# Tool Call
tool_calls: list[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"], ensure_ascii=False
),
),
)
for raw_function_call in function_call_arr
]
# get any content before the tool call
content = model_output.split(self.bot_token)[0]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if len(content) > 0 else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=tool_content
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if self.bot_token not in current_text:
return DeltaMessage(content=delta_text)
# if the tool call token ID IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle if we detected the BOT token which means the start of tool
# calling
if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1:
# if it's the only token, return None, so we don't send a chat
# completion any don't send a control token
return None
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr = current_text.split(self.bot_token)[-1]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags
)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("not enough tokens to parse into JSON yet")
return None
# select as the current tool call the one we're on the state at
current_tool_call: dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
diff: str | None = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id], ""
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# case: update an existing tool - this is handled below
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=MistralToolCall.generate_random_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
cur_arguments = current_tool_call.get("arguments")
new_text = delta_text.replace("'", '"')
if '"}' in new_text:
new_text = new_text[: new_text.rindex('"}')]
if not cur_arguments and not prev_arguments:
delta = None
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset mid-arguments"
)
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)[
:-2
]
logger.debug("finding %s in %s", new_text, cur_arguments_json)
if new_text not in cur_arguments_json:
return None
arguments_delta = cur_arguments_json[
: cur_arguments_json.rindex(new_text) + len(new_text)
]
logger.debug(
"First tokens in arguments received: %s", arguments_delta
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
logger.debug(
"Searching for diff between \n%s\n%s",
cur_args_json,
prev_args_json,
)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json
)
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += argument_diff
else:
# try parsing it with regular JSON - if it works we're
# at the end, and we need to send the difference between
# tokens streamed so far and the valid JSON
delta = None
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import Final, cast
from typing import cast
import jinja2
import numpy as np
......@@ -11,8 +11,18 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
ClassificationServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
......@@ -29,68 +39,60 @@ from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
ClassificationServeContext = ServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
class ClassificationMixin(OpenAIServing):
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
trust_request_chat_template: bool
async def _preprocess(
self,
ctx: ClassificationServeContext,
ctx: ServeContext,
) -> ErrorResponse | None:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
ctx = cast(ClassificationServeContext, ctx)
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, ClassificationChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
request_obj = ctx.request
if isinstance(request_obj, ClassificationChatRequest):
chat_request = request_obj
messages = chat_request.messages
trust_request_chat_template = getattr(
self,
"trust_request_chat_template",
False,
)
if error_check_ret:
return error_check_ret
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
)
if ret:
return ret
_, engine_prompts = await self._preprocess_chat(
ctx.request,
cast(ChatCompletionRequest, chat_request),
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
messages,
chat_template=(
chat_request.chat_template
or getattr(self, "chat_template", None)
),
chat_template_content_format=cast(
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
)
ctx.engine_prompts = engine_prompts
elif isinstance(ctx.request, ClassificationCompletionRequest):
input_data = ctx.request.input
elif isinstance(request_obj, ClassificationCompletionRequest):
completion_request = request_obj
input_data = completion_request.input
if input_data in (None, ""):
return self.create_error_response(
"Input or messages must be provided",
......@@ -104,10 +106,13 @@ class ServingClassification(OpenAIServing):
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(ctx.request),
config=self._build_render_config(completion_request),
)
else:
return self.create_error_response("Invalid classification request type")
return self.create_error_response(
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return None
......@@ -117,14 +122,13 @@ class ServingClassification(OpenAIServing):
def _build_response(
self,
ctx: ClassificationServeContext,
ctx: ServeContext,
) -> ClassificationResponse | ErrorResponse:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
id2label = getattr(self.model_config.hf_config, "id2label", {})
ctx = cast(ClassificationServeContext, ctx)
items: list[ClassificationData] = []
num_prompt_tokens = 0
......@@ -135,7 +139,9 @@ class ServingClassification(OpenAIServing):
probs = classify_res.probs
predicted_index = int(np.argmax(probs))
label = id2label.get(predicted_index)
label = getattr(self.model_config.hf_config, "id2label", {}).get(
predicted_index
)
item = ClassificationData(
index=idx,
......@@ -168,6 +174,32 @@ class ServingClassification(OpenAIServing):
add_special_tokens=request.add_special_tokens,
)
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify(
self,
request: ClassificationRequest,
......@@ -183,11 +215,11 @@ class ServingClassification(OpenAIServing):
request_id=request_id,
)
return await self.handle(ctx) # type: ignore[return-value]
return await super().handle(ctx) # type: ignore
def _create_pooling_params(
self,
ctx: ClassificationServeContext,
ctx: ServeContext[ClassificationRequest],
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
......@@ -198,4 +230,4 @@ class ServingClassification(OpenAIServing):
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params
return pooling_params
\ No newline at end of file
......@@ -6,13 +6,21 @@ from typing import Any, Final, cast
import torch
from fastapi import Request
from typing_extensions import assert_never
from fastapi.responses import Response
from typing_extensions import assert_never, override
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
......@@ -25,11 +33,19 @@ from vllm.entrypoints.pooling.embed.protocol import (
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.outputs import (
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes,
encode_pooling_output,
)
......@@ -37,33 +53,9 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__)
EmbeddingServeContext = ServeContext[EmbeddingRequest]
class OpenAIServingEmbedding(OpenAIServing):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
class EmbeddingMixin(OpenAIServing):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pooler_config = self.model_config.pooler_config
......@@ -77,41 +69,32 @@ class OpenAIServingEmbedding(OpenAIServing):
else None
)
@override
async def _preprocess(
self,
ctx: EmbeddingServeContext,
ctx: ServeContext,
) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
chat_template=ctx.request.chat_template or ctx.chat_template,
chat_template_content_format=ctx.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
)
elif isinstance(ctx.request, EmbeddingCompletionRequest):
else:
renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),
)
else:
return self.create_error_response("Invalid classification request type")
return None
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
......@@ -130,15 +113,16 @@ class OpenAIServingEmbedding(OpenAIServing):
add_special_tokens=request.add_special_tokens,
)
@override
def _build_response(
self,
ctx: EmbeddingServeContext,
) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
final_res_batch_checked = ctx.final_res_batch
ctx: ServeContext,
) -> EmbeddingResponse | Response | ErrorResponse:
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
encoding_format = ctx.request.encoding_format
embed_dtype = ctx.request.embed_dtype
endianness = ctx.request.endianness
encoding_format: EncodingFormat = ctx.request.encoding_format
embed_dtype: EmbedDType = ctx.request.embed_dtype
endianness: Endianness = ctx.request.endianness
def encode_float_base64():
items: list[EmbeddingResponseData] = []
......@@ -219,8 +203,8 @@ class OpenAIServingEmbedding(OpenAIServing):
self,
ctx: EmbeddingServeContext,
token_ids: list[int],
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
pooling_params,
trace_headers,
prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing."""
......@@ -262,7 +246,7 @@ class OpenAIServingEmbedding(OpenAIServing):
def _validate_input(
self,
request: object,
request,
input_ids: list[int],
input_text: str,
) -> TokensPrompt:
......@@ -342,7 +326,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
) -> AsyncGenerator[PoolingRequestOutput, None]:
) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
"""Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}"
......@@ -363,6 +347,7 @@ class OpenAIServingEmbedding(OpenAIServing):
priority=getattr(ctx.request, "priority", 0),
)
@override
async def _prepare_generators(
self,
ctx: ServeContext,
......@@ -378,7 +363,9 @@ class OpenAIServingEmbedding(OpenAIServing):
return await super()._prepare_generators(ctx)
# Custom logic for chunked processing
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try:
trace_headers = (
......@@ -432,9 +419,10 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@override
async def _collect_batch(
self,
ctx: EmbeddingServeContext,
ctx: ServeContext,
) -> ErrorResponse | None:
"""Collect and aggregate batch results
with support for chunked processing.
......@@ -443,6 +431,7 @@ class OpenAIServingEmbedding(OpenAIServing):
minimize memory usage.
For regular requests, collects results normally.
"""
ctx = cast(EmbeddingServeContext, ctx)
try:
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
......@@ -538,10 +527,12 @@ class OpenAIServingEmbedding(OpenAIServing):
except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = result
short_prompts_results[prompt_idx] = cast(
PoolingRequestOutput, result
)
# Finalize aggregated results
final_res_batch: list[PoolingRequestOutput] = []
final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []
num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts):
......@@ -589,19 +580,49 @@ class OpenAIServingEmbedding(OpenAIServing):
f"Failed to aggregate chunks for prompt {prompt_idx}"
)
elif prompt_idx in short_prompts_results:
final_res_batch.append(short_prompts_results[prompt_idx])
final_res_batch.append(
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
)
else:
return self.create_error_response(
f"Result not found for prompt {prompt_idx}"
)
ctx.final_res_batch = final_res_batch
ctx.final_res_batch = cast(
list[RequestOutput | PoolingRequestOutput], final_res_batch
)
return None
except Exception as e:
return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_embedding(
self,
request: EmbeddingRequest,
......@@ -624,13 +645,16 @@ class OpenAIServingEmbedding(OpenAIServing):
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
return await self.handle(ctx) # type: ignore[return-value]
return await super().handle(ctx) # type: ignore
@override
def _create_pooling_params(
self,
ctx: EmbeddingServeContext,
ctx: ServeContext[EmbeddingRequest],
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
......@@ -642,3 +666,17 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e))
return pooling_params
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
return await super()._preprocess(ctx)
\ No newline at end of file
......@@ -17,10 +17,8 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
......@@ -34,15 +32,11 @@ if TYPE_CHECKING:
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
else:
ChatCompletionRequest = object
CompletionRequest = object
StreamOptions = object
LoRAModulePath = object
ResponsesRequest = object
logger = init_logger(__name__)
......@@ -217,26 +211,11 @@ def _validate_truncation_size(
def get_max_tokens(
max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
prompt: TokensPrompt | EmbedsPrompt,
request: "ChatCompletionRequest | CompletionRequest",
input_length: int,
default_sampling_params: dict,
) -> int:
# NOTE: Avoid isinstance() for better efficiency
max_tokens: int | None = None
if max_tokens is None:
# ChatCompletionRequest
max_tokens = getattr(request, "max_completion_tokens", None)
if max_tokens is None:
# ResponsesRequest
max_tokens = getattr(request, "max_output_tokens", None)
if max_tokens is None:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)
......@@ -343,4 +322,4 @@ def log_version_and_model(lgr: Logger, version: str, model_name: str) -> None:
message = logo_template.substitute(colors)
lgr.info(message, version, model_name)
lgr.info(message, version, model_name)
\ No newline at end of file
......@@ -87,7 +87,6 @@ if TYPE_CHECKING:
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
VLLM_PLUGINS: list[str] | None = None
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None
# Deprecated env variables for profiling, kept for backward compatibility
# See also vllm/config/profiler.py and `--profiler-config` argument
VLLM_TORCH_CUDA_PROFILE: str | None = None
......@@ -327,11 +326,16 @@ def use_aot_compile() -> bool:
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = (
"1"
if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache()
if is_torch_equal_or_newer("2.10.0.dev")
and not disable_compile_cache()
# Disabling AOT_COMPILE for CPU
# See: https://github.com/vllm-project/vllm/issues/32033
and not current_platform.is_cpu()
else "0"
)
......@@ -912,13 +916,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
"VLLM_LORA_RESOLVER_CACHE_DIR", None
),
# A remote HF repo(s) containing one or more LoRA adapters, which
# may be downloaded and leveraged as needed. Only works if plugins
# are enabled and VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled.
# Values should be comma separated.
"VLLM_LORA_RESOLVER_HF_REPO_LIST": lambda: os.getenv(
"VLLM_LORA_RESOLVER_HF_REPO_LIST", None
),
# Enables torch CUDA profiling if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time
from collections import defaultdict
from contextlib import contextmanager
......@@ -13,7 +12,6 @@ import torch
import vllm.envs as envs
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
......@@ -428,15 +426,3 @@ def set_forward_context(
),
forward_stats,
)
_profiling: bool = False
@contextmanager
def set_profilling(profiling):
global _profiling
_profiling = profiling
def get_profilling() -> bool:
global _profiling
return _profiling
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logging_utils.access_log_filter import (
UvicornAccessLogFilter,
create_uvicorn_log_config,
)
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime
......@@ -12,8 +8,6 @@ from vllm.logging_utils.log_time import logtime
__all__ = [
"NewLineFormatter",
"ColoredFormatter",
"UvicornAccessLogFilter",
"create_uvicorn_log_config",
"lazy",
"logtime",
]
]
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Access log filter for uvicorn to exclude specific endpoints from logging.
This module provides a logging filter that can be used to suppress access logs
for specific endpoints (e.g., /health, /metrics) to reduce log noise in
production environments.
"""
import logging
from urllib.parse import urlparse
class UvicornAccessLogFilter(logging.Filter):
"""
A logging filter that excludes access logs for specified endpoint paths.
This filter is designed to work with uvicorn's access logger. It checks
the log record's arguments for the request path and filters out records
matching the excluded paths.
Uvicorn access log format:
'%s - "%s %s HTTP/%s" %d'
(client_addr, method, path, http_version, status_code)
Example:
127.0.0.1:12345 - "GET /health HTTP/1.1" 200
Args:
excluded_paths: A list of URL paths to exclude from logging.
Paths are matched exactly.
Example: ["/health", "/metrics"]
"""
def __init__(self, excluded_paths: list[str] | None = None):
super().__init__()
self.excluded_paths = set(excluded_paths or [])
def filter(self, record: logging.LogRecord) -> bool:
"""
Determine if the log record should be logged.
Args:
record: The log record to evaluate.
Returns:
True if the record should be logged, False otherwise.
"""
if not self.excluded_paths:
return True
# This filter is specific to uvicorn's access logs.
if record.name != "uvicorn.access":
return True
# The path is the 3rd argument in the log record's args tuple.
# See uvicorn's access logging implementation for details.
log_args = record.args
if isinstance(log_args, tuple) and len(log_args) >= 3:
path_with_query = log_args[2]
# Get path component without query string.
if isinstance(path_with_query, str):
path = urlparse(path_with_query).path
if path in self.excluded_paths:
return False
return True
def create_uvicorn_log_config(
excluded_paths: list[str] | None = None,
log_level: str = "info",
) -> dict:
"""
Create a uvicorn logging configuration with access log filtering.
This function generates a logging configuration dictionary that can be
passed to uvicorn's `log_config` parameter. It sets up the access log
filter to exclude specified paths.
Args:
excluded_paths: List of URL paths to exclude from access logs.
log_level: The log level for uvicorn loggers.
Returns:
A dictionary containing the logging configuration.
Example:
>>> config = create_uvicorn_log_config(["/health", "/metrics"])
>>> uvicorn.run(app, log_config=config)
"""
config = {
"version": 1,
"disable_existing_loggers": False,
"filters": {
"access_log_filter": {
"()": UvicornAccessLogFilter,
"excluded_paths": excluded_paths or [],
},
},
"formatters": {
"default": {
"()": "uvicorn.logging.DefaultFormatter",
"fmt": "%(levelprefix)s %(message)s",
"use_colors": None,
},
"access": {
"()": "uvicorn.logging.AccessFormatter",
"fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s', # noqa: E501
},
},
"handlers": {
"default": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"access": {
"formatter": "access",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
"filters": ["access_log_filter"],
},
},
"loggers": {
"uvicorn": {
"handlers": ["default"],
"level": log_level.upper(),
"propagate": False,
},
"uvicorn.error": {
"level": log_level.upper(),
"handlers": ["default"],
"propagate": False,
},
"uvicorn.access": {
"handlers": ["access"],
"level": log_level.upper(),
"propagate": False,
},
},
}
return config
......@@ -62,7 +62,6 @@ def _fused_moe_lora_kernel(
num_experts,
lora_ids,
adapter_enabled,
max_loras, # <<< PR2: rename, used for masks when grid axis-2 != max_loras
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
......@@ -84,7 +83,6 @@ def _fused_moe_lora_kernel(
num_slice_c: tl.constexpr,
top_k: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
USE_B_L2_CACHE: tl.constexpr, # new, enable .ca load for B
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
......@@ -106,13 +104,10 @@ def _fused_moe_lora_kernel(
if moe_enabled == 0:
# Early exit for the no moe lora case.
return
# The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel.
# This guard ensures we don't access sorted_token_ids / expert_ids /
# num_tokens_post_padded beyond their allocated bounds if an invalid
# lora_id somehow appears. Although the caller should pass correct
# max_loras, defensive programming prevents accidental out-of-bounds.
if lora_id >= max_loras:
return
# The grid size on axis 2 is (max_loras + 1) to handle the no-lora case
# (lora_id == -1), but sorted_token_ids and expert_ids are allocated with
# shape (max_loras, ...). Use (num_programs - 1) for correct bounds checking.
max_loras = tl.num_programs(axis=2) - 1
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n
......@@ -141,11 +136,10 @@ def _fused_moe_lora_kernel(
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
# remove modulo wrap-around
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int32)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_id + offs_token_id
offs_token = tl.load(
sorted_token_ids_ptr + token_ind,
......@@ -182,13 +176,7 @@ def _fused_moe_lora_kernel(
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
# pre-fetch lora weight
# add (offs_bn < N) mask; optional .ca for B
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
if USE_B_L2_CACHE:
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
else:
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
a = tl.load(
......@@ -288,7 +276,6 @@ def _fused_moe_lora_shrink(
num_experts,
lora_ids,
adapter_enabled,
lora_a_stacked[0].shape[0],
qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0),
......@@ -305,7 +292,6 @@ def _fused_moe_lora_shrink(
num_slice_c=num_slices,
top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False,
USE_B_L2_CACHE=True, # new
IS_PRIMARY=True,
**shrink_config,
)
......@@ -391,7 +377,6 @@ def _fused_moe_lora_expand(
num_experts,
lora_ids,
adapter_enabled,
lora_b_stacked[0].shape[0],
a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1),
w1_lora_b_stacked.stride(0),
......@@ -408,7 +393,6 @@ def _fused_moe_lora_expand(
num_slice_c=num_slices,
top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_B_L2_CACHE=True, # new
IS_PRIMARY=False,
**expand_config,
)
......@@ -697,4 +681,4 @@ try:
except AttributeError:
fused_moe_lora = _fused_moe_lora
fused_moe_lora_shrink = _fused_moe_lora_shrink
fused_moe_lora_expand = _fused_moe_lora_expand
fused_moe_lora_expand = _fused_moe_lora_expand
\ No newline at end of file
......@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE):
use_overlapped
and not (
(self.enable_eplb and backend != "allgather_reducescatter")
or self.moe_parallel_config.use_fi_all2allv_kernels
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
)
and self._shared_experts is not None
)
......@@ -93,4 +93,4 @@ class SharedFusedMoE(FusedMoE):
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out, fused_out
return shared_out, fused_out
\ No newline at end of file
......@@ -193,7 +193,6 @@ class RMSNorm(CustomOp):
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
x = x.to(orig_dtype)
if weight is not None:
......
......@@ -380,7 +380,6 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
eps: float | None = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
......@@ -391,8 +390,6 @@ class ReplicatedLinear(LinearBase):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]
self.eps = eps
super().__init__(
input_size,
......@@ -643,7 +640,6 @@ class ColumnParallelLinear(LinearBase):
if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......@@ -720,13 +716,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
eps: float | None = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.eps = eps
self.output_sizes = output_sizes
self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
......@@ -1366,7 +1360,6 @@ class QKVParallelLinear(ColumnParallelLinear):
if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......@@ -1475,7 +1468,6 @@ class RowParallelLinear(LinearBase):
)
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
......@@ -1516,7 +1508,6 @@ class RowParallelLinear(LinearBase):
if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......
......@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
......@@ -503,9 +502,6 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1,
)
# Check if running on Blackwell (SM100+) for kernel tuning
self.is_blackwell = current_platform.is_device_capability_family(100)
def forward_native(
self,
hidden_states: torch.Tensor,
......@@ -887,7 +883,6 @@ class MambaMixer2(MambaBase, CustomOp):
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
is_blackwell=self.is_blackwell,
)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
......@@ -938,4 +933,4 @@ direct_register_custom_op(
op_func=mamba_mixer2,
mutates_args=["output"],
fake_impl=mamba_mixer2_fake,
)
)
\ No newline at end of file
......@@ -286,7 +286,6 @@ def selective_state_update(
out=None,
num_accepted_tokens=None,
cu_seqlens=None,
is_blackwell=False,
):
"""
Argument:
......@@ -392,26 +391,17 @@ def selective_state_update(
if dst_state_batch_indices is not None
else (0, 0)
)
# We don't want autotune since it will overwrite the state.
# We instead tune by hand based on dstate.
# Default
BLOCK_SIZE_M, num_warps = 4, 8
if dstate <= 16:
BLOCK_SIZE_M, num_warps = 32, 4
elif dstate <= 32:
BLOCK_SIZE_M, num_warps = 16, 4
elif dstate <= 64:
BLOCK_SIZE_M, num_warps = 8, 4
else:
# dstate > 64
if is_blackwell:
# Optimized for B200 with dstate>64
BLOCK_SIZE_M, num_warps = 32, 8
elif dstate <= 128:
BLOCK_SIZE_M, num_warps = 4, 4
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = (
(32, 4)
if dstate <= 16
else (
(16, 4)
if dstate <= 32
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
)
)
tie_hdim = (
A.stride(-1) == 0
and A.stride(-2) == 0
......@@ -593,4 +583,4 @@ def selective_scan_fn(
if z is None:
return delta # output written inplace to delta
else:
return z # output written inplace to z
return z # output written inplace to z
\ No newline at end of file
......@@ -188,6 +188,7 @@ class CompressedTensorsConfig(QuantizationConfig):
else:
return quant_method
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
......
......@@ -42,6 +42,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
......@@ -51,6 +52,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
......@@ -64,6 +66,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
......@@ -95,7 +98,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils import W8a8GetCacheJSON
logger = init_logger(__name__)
......@@ -240,6 +242,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights(
self,
......@@ -316,7 +319,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
)
def process_weights_after_loading(self, layer: FusedMoE) -> None:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False
)
......@@ -331,12 +334,10 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
self.moe_mk = make_nvfp4_moe_kernel(
self.kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
def apply(
......@@ -346,8 +347,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
return self.moe_mk(
assert self.kernel is not None
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -378,10 +379,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
activation_key=None if use_a16 else kNvfp4Dynamic,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights(
self,
layer: torch.nn.Module,
......@@ -495,7 +505,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: FusedMoE) -> None:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Convert NVFP4 MoE weights into kernel format and setup the kernel.
"""
......@@ -561,33 +571,48 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None
self.moe_mk = make_nvfp4_moe_kernel(
self.kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
assert self.moe_quant_config is not None
assert self.experts_cls is not None
return make_nvfp4_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
)
def get_fused_moe_quant_config(
......@@ -658,8 +683,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
)
else:
assert self.moe_mk is not None
return self.moe_mk(
assert self.kernel is not None
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -733,6 +758,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
allow_vllm_cutlass=True,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights(
self,
layer: torch.nn.Module,
......@@ -891,28 +925,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: FusedMoE) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
pass
def process_weights_after_loading(self, layer: FusedMoE) -> None:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight
w2 = layer.w2_weight
......@@ -974,34 +988,49 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
self.kernel, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
) -> FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
assert self.experts_cls is not None
return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
)
def get_fused_moe_quant_config(
......@@ -1080,12 +1109,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.moe_mk is not None
return self.moe_mk(
assert self.kernel is not None
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -1134,7 +1161,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales."
)
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights(
self,
......@@ -1203,22 +1229,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
pass
def get_fused_moe_quant_config(
......@@ -1238,8 +1248,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -1255,8 +1263,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
use_fused_gate=use_fused_gate,
use_nn_moe=False,
)
......@@ -1869,7 +1875,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
return True
# TODO @gaoqiong
class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
"""
CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform
......@@ -2497,4 +2502,4 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
@property
def supports_eplb(self) -> bool:
return False
return False
\ No newline at end of file
......@@ -16,9 +16,6 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale,
)
from vllm.model_executor.parameter import (
......@@ -162,20 +159,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
if self.backend == "fbgemm":
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
# Pad weights for CUTLASS/FlashInfer kernel alignment (K and N
# divisible by 32). fbgemm has its own layout requirements.
if self.backend in ("cutlass", "flashinfer-cutlass"):
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight_packed.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
layer.weights_padding_cols = 0
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False
)
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False
)
layer.alpha = Parameter(
1 / (layer.input_global_scale * layer.weight_global_scale),
......@@ -201,8 +187,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
return out
output_dtype = x.dtype
output_size = layer.output_size_per_partition
output_shape = [*x.shape[:-1], output_size]
output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(
......@@ -212,10 +197,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
backend=self.backend,
)
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = (
x_fp4,
layer.weight_packed,
......@@ -240,9 +221,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None:
out = out + bias
return out.view(*output_shape)
return out.view(*output_shape)
\ No newline at end of file
......@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
......@@ -51,6 +52,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
......@@ -676,6 +678,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_vllm_cutlass=False,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights(
self,
layer: Module,
......@@ -801,7 +812,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel(
self,
layer: FusedMoE,
layer: Module,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
......@@ -833,15 +844,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
self.kernel, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer: Module) -> None:
......@@ -896,19 +908,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
assert self.moe_quant_config is not None
assert self.experts_cls is not None
return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
)
def get_fused_moe_quant_config(
......@@ -948,7 +974,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,**_,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
......@@ -1002,9 +1028,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
assert self.kernel is not None
assert not self.is_monolithic
return self.moe_mk(
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -1164,4 +1190,4 @@ class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
def __init__(self, quant_config: Fp8Config):
super().__init__(quant_config)
super().__init__(quant_config)
\ No newline at end of file
......@@ -304,37 +304,6 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
def __init__(self, quant_config: Fp8Config):
super().__init__(quant_config)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment