Commit 899a2db4 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1(ex fused_moe&models)

parent 78c1f9e5
...@@ -63,7 +63,6 @@ from vllm.engine.protocol import EngineClient ...@@ -63,7 +63,6 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
make_tool_call_id,
) )
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.mcp.tool_server import ToolServer
...@@ -116,7 +115,6 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -116,7 +115,6 @@ from vllm.entrypoints.openai.responses.utils import (
extract_tool_types, extract_tool_types,
should_continue_final_message, should_continue_final_message,
) )
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -252,17 +250,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -252,17 +250,6 @@ class OpenAIServingResponses(OpenAIServing):
self.default_sampling_params["stop_token_ids"].extend( self.default_sampling_params["stop_token_ids"].extend(
get_stop_tokens_for_assistant_actions() get_stop_tokens_for_assistant_actions()
) )
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.enable_auto_tools = enable_auto_tools self.enable_auto_tools = enable_auto_tools
# set up tool use # set up tool use
self.tool_parser = self._get_tool_parser( self.tool_parser = self._get_tool_parser(
...@@ -436,11 +423,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -436,11 +423,8 @@ class OpenAIServingResponses(OpenAIServing):
if maybe_error is not None: if maybe_error is not None:
return maybe_error return maybe_error
default_max_tokens = get_max_tokens( default_max_tokens = self.max_model_len - len(
self.max_model_len, engine_prompt["prompt_token_ids"]
request,
engine_prompt,
self.default_sampling_params,
) )
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
...@@ -970,11 +954,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -970,11 +954,8 @@ class OpenAIServingResponses(OpenAIServing):
enable_auto_tools=self.enable_auto_tools, enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser, tool_parser_cls=self.tool_parser,
) )
if content or (self.use_harmony and tool_calls):
res_text_part = None
if content: if content:
res_text_part = ResponseOutputText( output_text = ResponseOutputText(
text=content, text=content,
annotations=[], # TODO annotations=[], # TODO
type="output_text", type="output_text",
...@@ -991,7 +972,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -991,7 +972,7 @@ class OpenAIServingResponses(OpenAIServing):
) )
message_item = ResponseOutputMessage( message_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}", id=f"msg_{random_uuid()}",
content=[res_text_part] if res_text_part else [], content=[output_text],
role="assistant", role="assistant",
status="completed", status="completed",
type="message", type="message",
...@@ -1003,28 +984,17 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1003,28 +984,17 @@ class OpenAIServingResponses(OpenAIServing):
if message_item: if message_item:
outputs.append(message_item) outputs.append(message_item)
if tool_calls: if tool_calls:
# We use a simple counter for history_tool_call_count because tool_call_items = [
# we don't track the history of tool calls in the Responses API yet.
# This means that the tool call index will start from 0 for each
# request.
tool_call_items = []
for history_tool_call_cnt, tool_call in enumerate(tool_calls):
tool_call_items.append(
ResponseFunctionToolCall( ResponseFunctionToolCall(
id=f"fc_{random_uuid()}", id=f"fc_{random_uuid()}",
call_id=tool_call.id call_id=f"call_{random_uuid()}",
if tool_call.id
else make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
),
type="function_call", type="function_call",
status="completed", status="completed",
name=tool_call.name, name=tool_call.name,
arguments=tool_call.arguments, arguments=tool_call.arguments,
) )
) for tool_call in tool_calls
]
outputs.extend(tool_call_items) outputs.extend(tool_call_items)
return outputs return outputs
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from random import choices
from string import ascii_letters, digits
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from pydantic import Field
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
class MistralToolCall(ToolCall):
id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id())
@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))
@staticmethod
def is_valid_id(id: str) -> bool:
return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
return (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
)
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
- the examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral model...")
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[
str
] = [] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if _is_fn_name_regex_support(self.model_tokenizer):
self.fn_name_regex = re.compile(
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
)
else:
self.fn_name_regex = None
if self.bot_token_id is None:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!"
)
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if (
not isinstance(self.model_tokenizer, MistralTokenizer)
and request.tools
and request.tool_choice != "none"
):
# Do not skip special tokens when using chat template
# with Mistral parser as TOOL_CALL token is needed
# for tool detection.
# Note: we don't want skip_special_tokens=False
# with MistralTokenizer as it is incompatible
request.skip_special_tokens = False
return request
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing,
make sure your tool call arguments don't ever include quotes!
"""
# case -- if a tool call token is not present, return a text response
if self.bot_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# first remove the BOT token
tool_content = model_output.replace(self.bot_token, "").strip()
try:
# we first try to directly load the json as parsing very nested
# jsons is difficult
try:
if self.fn_name_regex:
matches = self.fn_name_regex.findall(tool_content)
function_call_arr = []
for match in matches:
fn_name = match[0]
args = match[1]
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)}
)
else:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
function_call_arr = json.loads(raw_tool_call)
# Tool Call
tool_calls: list[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"], ensure_ascii=False
),
),
)
for raw_function_call in function_call_arr
]
# get any content before the tool call
content = model_output.split(self.bot_token)[0]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if len(content) > 0 else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=tool_content
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if self.bot_token not in current_text:
return DeltaMessage(content=delta_text)
# if the tool call token ID IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle if we detected the BOT token which means the start of tool
# calling
if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1:
# if it's the only token, return None, so we don't send a chat
# completion any don't send a control token
return None
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr = current_text.split(self.bot_token)[-1]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags
)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("not enough tokens to parse into JSON yet")
return None
# select as the current tool call the one we're on the state at
current_tool_call: dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
diff: str | None = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id], ""
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# case: update an existing tool - this is handled below
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=MistralToolCall.generate_random_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
cur_arguments = current_tool_call.get("arguments")
new_text = delta_text.replace("'", '"')
if '"}' in new_text:
new_text = new_text[: new_text.rindex('"}')]
if not cur_arguments and not prev_arguments:
delta = None
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset mid-arguments"
)
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)[
:-2
]
logger.debug("finding %s in %s", new_text, cur_arguments_json)
if new_text not in cur_arguments_json:
return None
arguments_delta = cur_arguments_json[
: cur_arguments_json.rindex(new_text) + len(new_text)
]
logger.debug(
"First tokens in arguments received: %s", arguments_delta
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
logger.debug(
"Searching for diff between \n%s\n%s",
cur_args_json,
prev_args_json,
)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json
)
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += argument_diff
else:
# try parsing it with regular JSON - if it works we're
# at the end, and we need to send the difference between
# tokens streamed so far and the valid JSON
delta = None
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus from http import HTTPStatus
from typing import Final, cast from typing import cast
import jinja2 import jinja2
import numpy as np import numpy as np
...@@ -11,8 +11,18 @@ from fastapi import Request ...@@ -11,8 +11,18 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo from vllm.entrypoints.openai.chat_completion.protocol import (
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
ClassificationServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
...@@ -29,68 +39,60 @@ from vllm.pooling_params import PoolingParams ...@@ -29,68 +39,60 @@ from vllm.pooling_params import PoolingParams
logger = init_logger(__name__) logger = init_logger(__name__)
ClassificationServeContext = ServeContext[ClassificationRequest] class ClassificationMixin(OpenAIServing):
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
class ServingClassification(OpenAIServing): trust_request_chat_template: bool
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess( async def _preprocess(
self, self,
ctx: ClassificationServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
""" """
Process classification inputs: tokenize text, resolve adapters, Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs. and prepare model-specific inputs.
""" """
ctx = cast(ClassificationServeContext, ctx)
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) request_obj = ctx.request
if isinstance(ctx.request, ClassificationChatRequest): if isinstance(request_obj, ClassificationChatRequest):
error_check_ret = self._validate_chat_template( chat_request = request_obj
request_chat_template=ctx.request.chat_template, messages = chat_request.messages
chat_template_kwargs=ctx.request.chat_template_kwargs, trust_request_chat_template = getattr(
trust_request_chat_template=self.trust_request_chat_template, self,
"trust_request_chat_template",
False,
)
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
) )
if error_check_ret: if ret:
return error_check_ret return ret
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
ctx.request, cast(ChatCompletionRequest, chat_request),
self.renderer, self.renderer,
ctx.request.messages, messages,
chat_template=ctx.request.chat_template or self.chat_template, chat_template=(
chat_template_content_format=self.chat_template_content_format, chat_request.chat_template
add_generation_prompt=ctx.request.add_generation_prompt, or getattr(self, "chat_template", None)
continue_final_message=ctx.request.continue_final_message, ),
add_special_tokens=ctx.request.add_special_tokens, chat_template_content_format=cast(
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
) )
ctx.engine_prompts = engine_prompts ctx.engine_prompts = engine_prompts
elif isinstance(ctx.request, ClassificationCompletionRequest): elif isinstance(request_obj, ClassificationCompletionRequest):
input_data = ctx.request.input completion_request = request_obj
input_data = completion_request.input
if input_data in (None, ""): if input_data in (None, ""):
return self.create_error_response( return self.create_error_response(
"Input or messages must be provided", "Input or messages must be provided",
...@@ -104,10 +106,13 @@ class ServingClassification(OpenAIServing): ...@@ -104,10 +106,13 @@ class ServingClassification(OpenAIServing):
prompt_input = cast(str | list[str], input_data) prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input, prompt_or_prompts=prompt_input,
config=self._build_render_config(ctx.request), config=self._build_render_config(completion_request),
) )
else: else:
return self.create_error_response("Invalid classification request type") return self.create_error_response(
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return None return None
...@@ -117,14 +122,13 @@ class ServingClassification(OpenAIServing): ...@@ -117,14 +122,13 @@ class ServingClassification(OpenAIServing):
def _build_response( def _build_response(
self, self,
ctx: ClassificationServeContext, ctx: ServeContext,
) -> ClassificationResponse | ErrorResponse: ) -> ClassificationResponse | ErrorResponse:
""" """
Convert model outputs to a formatted classification response Convert model outputs to a formatted classification response
with probabilities and labels. with probabilities and labels.
""" """
id2label = getattr(self.model_config.hf_config, "id2label", {}) ctx = cast(ClassificationServeContext, ctx)
items: list[ClassificationData] = [] items: list[ClassificationData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
...@@ -135,7 +139,9 @@ class ServingClassification(OpenAIServing): ...@@ -135,7 +139,9 @@ class ServingClassification(OpenAIServing):
probs = classify_res.probs probs = classify_res.probs
predicted_index = int(np.argmax(probs)) predicted_index = int(np.argmax(probs))
label = id2label.get(predicted_index) label = getattr(self.model_config.hf_config, "id2label", {}).get(
predicted_index
)
item = ClassificationData( item = ClassificationData(
index=idx, index=idx,
...@@ -168,6 +174,32 @@ class ServingClassification(OpenAIServing): ...@@ -168,6 +174,32 @@ class ServingClassification(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify( async def create_classify(
self, self,
request: ClassificationRequest, request: ClassificationRequest,
...@@ -183,11 +215,11 @@ class ServingClassification(OpenAIServing): ...@@ -183,11 +215,11 @@ class ServingClassification(OpenAIServing):
request_id=request_id, request_id=request_id,
) )
return await self.handle(ctx) # type: ignore[return-value] return await super().handle(ctx) # type: ignore
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: ClassificationServeContext, ctx: ServeContext[ClassificationRequest],
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
......
...@@ -6,13 +6,21 @@ from typing import Any, Final, cast ...@@ -6,13 +6,21 @@ from typing import Any, Final, cast
import torch import torch
from fastapi import Request from fastapi import Request
from typing_extensions import assert_never from fastapi.responses import Response
from typing_extensions import assert_never, override
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo from vllm.entrypoints.openai.engine.protocol import (
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse, EmbeddingBytesResponse,
...@@ -25,11 +33,19 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -25,11 +33,19 @@ from vllm.entrypoints.pooling.embed.protocol import (
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import (
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes, encode_pooling_bytes,
encode_pooling_output, encode_pooling_output,
) )
...@@ -37,33 +53,9 @@ from vllm.utils.serial_utils import ( ...@@ -37,33 +53,9 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
EmbeddingServeContext = ServeContext[EmbeddingRequest] class EmbeddingMixin(OpenAIServing):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class OpenAIServingEmbedding(OpenAIServing):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
pooler_config = self.model_config.pooler_config pooler_config = self.model_config.pooler_config
...@@ -77,41 +69,32 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -77,41 +69,32 @@ class OpenAIServingEmbedding(OpenAIServing):
else None else None
) )
@override
async def _preprocess( async def _preprocess(
self, self,
ctx: EmbeddingServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat( _, ctx.engine_prompts = await self._preprocess_chat(
ctx.request, ctx.request,
self.renderer, self.renderer,
ctx.request.messages, ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template, chat_template=ctx.request.chat_template or ctx.chat_template,
chat_template_content_format=self.chat_template_content_format, chat_template_content_format=ctx.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt, add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message, continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens, add_special_tokens=ctx.request.add_special_tokens,
) )
elif isinstance(ctx.request, EmbeddingCompletionRequest): else:
renderer = self._get_completion_renderer() renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input, prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request), config=self._build_render_config(ctx.request),
) )
else:
return self.create_error_response("Invalid classification request type")
return None return None
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
...@@ -130,15 +113,16 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -130,15 +113,16 @@ class OpenAIServingEmbedding(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
@override
def _build_response( def _build_response(
self, self,
ctx: EmbeddingServeContext, ctx: ServeContext,
) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse: ) -> EmbeddingResponse | Response | ErrorResponse:
final_res_batch_checked = ctx.final_res_batch final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
encoding_format = ctx.request.encoding_format encoding_format: EncodingFormat = ctx.request.encoding_format
embed_dtype = ctx.request.embed_dtype embed_dtype: EmbedDType = ctx.request.embed_dtype
endianness = ctx.request.endianness endianness: Endianness = ctx.request.endianness
def encode_float_base64(): def encode_float_base64():
items: list[EmbeddingResponseData] = [] items: list[EmbeddingResponseData] = []
...@@ -219,8 +203,8 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -219,8 +203,8 @@ class OpenAIServingEmbedding(OpenAIServing):
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
token_ids: list[int], token_ids: list[int],
pooling_params: PoolingParams, pooling_params,
trace_headers: Mapping[str, str] | None, trace_headers,
prompt_idx: int, prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]: ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing.""" """Process a single prompt using chunked processing."""
...@@ -262,7 +246,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -262,7 +246,7 @@ class OpenAIServingEmbedding(OpenAIServing):
def _validate_input( def _validate_input(
self, self,
request: object, request,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TokensPrompt: ) -> TokensPrompt:
...@@ -342,7 +326,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -342,7 +326,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
"""Create a generator for a single prompt using standard processing.""" """Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}" request_id_item = f"{ctx.request_id}-{prompt_index}"
...@@ -363,6 +347,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -363,6 +347,7 @@ class OpenAIServingEmbedding(OpenAIServing):
priority=getattr(ctx.request, "priority", 0), priority=getattr(ctx.request, "priority", 0),
) )
@override
async def _prepare_generators( async def _prepare_generators(
self, self,
ctx: ServeContext, ctx: ServeContext,
...@@ -378,7 +363,9 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -378,7 +363,9 @@ class OpenAIServingEmbedding(OpenAIServing):
return await super()._prepare_generators(ctx) return await super()._prepare_generators(ctx)
# Custom logic for chunked processing # Custom logic for chunked processing
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try: try:
trace_headers = ( trace_headers = (
...@@ -432,9 +419,10 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -432,9 +419,10 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
@override
async def _collect_batch( async def _collect_batch(
self, self,
ctx: EmbeddingServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Collect and aggregate batch results """Collect and aggregate batch results
with support for chunked processing. with support for chunked processing.
...@@ -443,6 +431,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -443,6 +431,7 @@ class OpenAIServingEmbedding(OpenAIServing):
minimize memory usage. minimize memory usage.
For regular requests, collects results normally. For regular requests, collects results normally.
""" """
ctx = cast(EmbeddingServeContext, ctx)
try: try:
if ctx.engine_prompts is None: if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
...@@ -538,10 +527,12 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -538,10 +527,12 @@ class OpenAIServingEmbedding(OpenAIServing):
except (ValueError, IndexError): except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = result short_prompts_results[prompt_idx] = cast(
PoolingRequestOutput, result
)
# Finalize aggregated results # Finalize aggregated results
final_res_batch: list[PoolingRequestOutput] = [] final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts): for prompt_idx in range(num_prompts):
...@@ -589,19 +580,49 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -589,19 +580,49 @@ class OpenAIServingEmbedding(OpenAIServing):
f"Failed to aggregate chunks for prompt {prompt_idx}" f"Failed to aggregate chunks for prompt {prompt_idx}"
) )
elif prompt_idx in short_prompts_results: elif prompt_idx in short_prompts_results:
final_res_batch.append(short_prompts_results[prompt_idx]) final_res_batch.append(
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
)
else: else:
return self.create_error_response( return self.create_error_response(
f"Result not found for prompt {prompt_idx}" f"Result not found for prompt {prompt_idx}"
) )
ctx.final_res_batch = final_res_batch ctx.final_res_batch = cast(
list[RequestOutput | PoolingRequestOutput], final_res_batch
)
return None return None
except Exception as e: except Exception as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_embedding( async def create_embedding(
self, self,
request: EmbeddingRequest, request: EmbeddingRequest,
...@@ -624,13 +645,16 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -624,13 +645,16 @@ class OpenAIServingEmbedding(OpenAIServing):
raw_request=raw_request, raw_request=raw_request,
model_name=model_name, model_name=model_name,
request_id=request_id, request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
) )
return await self.handle(ctx) # type: ignore[return-value] return await super().handle(ctx) # type: ignore
@override
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: EmbeddingServeContext, ctx: ServeContext[EmbeddingRequest],
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
...@@ -642,3 +666,17 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -642,3 +666,17 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e)) return self.create_error_response(str(e))
return pooling_params return pooling_params
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
return await super()._preprocess(ctx)
\ No newline at end of file
...@@ -17,10 +17,8 @@ from starlette.background import BackgroundTask, BackgroundTasks ...@@ -17,10 +17,8 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs from vllm import envs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.logger import current_formatter_type, init_logger from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -34,15 +32,11 @@ if TYPE_CHECKING: ...@@ -34,15 +32,11 @@ if TYPE_CHECKING:
StreamOptions, StreamOptions,
) )
from vllm.entrypoints.openai.models.protocol import LoRAModulePath from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
else: else:
ChatCompletionRequest = object ChatCompletionRequest = object
CompletionRequest = object CompletionRequest = object
StreamOptions = object StreamOptions = object
LoRAModulePath = object LoRAModulePath = object
ResponsesRequest = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -217,26 +211,11 @@ def _validate_truncation_size( ...@@ -217,26 +211,11 @@ def _validate_truncation_size(
def get_max_tokens( def get_max_tokens(
max_model_len: int, max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest", request: "ChatCompletionRequest | CompletionRequest",
prompt: TokensPrompt | EmbedsPrompt, input_length: int,
default_sampling_params: dict, default_sampling_params: dict,
) -> int: ) -> int:
# NOTE: Avoid isinstance() for better efficiency max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens
max_tokens: int | None = None
if max_tokens is None:
# ChatCompletionRequest
max_tokens = getattr(request, "max_completion_tokens", None)
if max_tokens is None:
# ResponsesRequest
max_tokens = getattr(request, "max_output_tokens", None)
if max_tokens is None:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
default_max_tokens = max_model_len - input_length default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length) max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
...@@ -87,7 +87,6 @@ if TYPE_CHECKING: ...@@ -87,7 +87,6 @@ if TYPE_CHECKING:
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
VLLM_PLUGINS: list[str] | None = None VLLM_PLUGINS: list[str] | None = None
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None
# Deprecated env variables for profiling, kept for backward compatibility # Deprecated env variables for profiling, kept for backward compatibility
# See also vllm/config/profiler.py and `--profiler-config` argument # See also vllm/config/profiler.py and `--profiler-config` argument
VLLM_TORCH_CUDA_PROFILE: str | None = None VLLM_TORCH_CUDA_PROFILE: str | None = None
...@@ -327,11 +326,16 @@ def use_aot_compile() -> bool: ...@@ -327,11 +326,16 @@ def use_aot_compile() -> bool:
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = ( default_value = (
"1" "1"
if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache() if is_torch_equal_or_newer("2.10.0.dev")
and not disable_compile_cache()
# Disabling AOT_COMPILE for CPU
# See: https://github.com/vllm-project/vllm/issues/32033
and not current_platform.is_cpu()
else "0" else "0"
) )
...@@ -912,13 +916,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -912,13 +916,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv( "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
"VLLM_LORA_RESOLVER_CACHE_DIR", None "VLLM_LORA_RESOLVER_CACHE_DIR", None
), ),
# A remote HF repo(s) containing one or more LoRA adapters, which
# may be downloaded and leveraged as needed. Only works if plugins
# are enabled and VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled.
# Values should be comma separated.
"VLLM_LORA_RESOLVER_HF_REPO_LIST": lambda: os.getenv(
"VLLM_LORA_RESOLVER_HF_REPO_LIST", None
),
# Enables torch CUDA profiling if set to 1. # Enables torch CUDA profiling if set to 1.
# Deprecated, see profiler_config. # Deprecated, see profiler_config.
"VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"), "VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
...@@ -13,7 +12,6 @@ import torch ...@@ -13,7 +12,6 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
...@@ -428,15 +426,3 @@ def set_forward_context( ...@@ -428,15 +426,3 @@ def set_forward_context(
), ),
forward_stats, forward_stats,
) )
_profiling: bool = False
@contextmanager
def set_profilling(profiling):
global _profiling
_profiling = profiling
def get_profilling() -> bool:
global _profiling
return _profiling
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logging_utils.access_log_filter import (
UvicornAccessLogFilter,
create_uvicorn_log_config,
)
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
from vllm.logging_utils.lazy import lazy from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime from vllm.logging_utils.log_time import logtime
...@@ -12,8 +8,6 @@ from vllm.logging_utils.log_time import logtime ...@@ -12,8 +8,6 @@ from vllm.logging_utils.log_time import logtime
__all__ = [ __all__ = [
"NewLineFormatter", "NewLineFormatter",
"ColoredFormatter", "ColoredFormatter",
"UvicornAccessLogFilter",
"create_uvicorn_log_config",
"lazy", "lazy",
"logtime", "logtime",
] ]
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Access log filter for uvicorn to exclude specific endpoints from logging.
This module provides a logging filter that can be used to suppress access logs
for specific endpoints (e.g., /health, /metrics) to reduce log noise in
production environments.
"""
import logging
from urllib.parse import urlparse
class UvicornAccessLogFilter(logging.Filter):
"""
A logging filter that excludes access logs for specified endpoint paths.
This filter is designed to work with uvicorn's access logger. It checks
the log record's arguments for the request path and filters out records
matching the excluded paths.
Uvicorn access log format:
'%s - "%s %s HTTP/%s" %d'
(client_addr, method, path, http_version, status_code)
Example:
127.0.0.1:12345 - "GET /health HTTP/1.1" 200
Args:
excluded_paths: A list of URL paths to exclude from logging.
Paths are matched exactly.
Example: ["/health", "/metrics"]
"""
def __init__(self, excluded_paths: list[str] | None = None):
super().__init__()
self.excluded_paths = set(excluded_paths or [])
def filter(self, record: logging.LogRecord) -> bool:
"""
Determine if the log record should be logged.
Args:
record: The log record to evaluate.
Returns:
True if the record should be logged, False otherwise.
"""
if not self.excluded_paths:
return True
# This filter is specific to uvicorn's access logs.
if record.name != "uvicorn.access":
return True
# The path is the 3rd argument in the log record's args tuple.
# See uvicorn's access logging implementation for details.
log_args = record.args
if isinstance(log_args, tuple) and len(log_args) >= 3:
path_with_query = log_args[2]
# Get path component without query string.
if isinstance(path_with_query, str):
path = urlparse(path_with_query).path
if path in self.excluded_paths:
return False
return True
def create_uvicorn_log_config(
excluded_paths: list[str] | None = None,
log_level: str = "info",
) -> dict:
"""
Create a uvicorn logging configuration with access log filtering.
This function generates a logging configuration dictionary that can be
passed to uvicorn's `log_config` parameter. It sets up the access log
filter to exclude specified paths.
Args:
excluded_paths: List of URL paths to exclude from access logs.
log_level: The log level for uvicorn loggers.
Returns:
A dictionary containing the logging configuration.
Example:
>>> config = create_uvicorn_log_config(["/health", "/metrics"])
>>> uvicorn.run(app, log_config=config)
"""
config = {
"version": 1,
"disable_existing_loggers": False,
"filters": {
"access_log_filter": {
"()": UvicornAccessLogFilter,
"excluded_paths": excluded_paths or [],
},
},
"formatters": {
"default": {
"()": "uvicorn.logging.DefaultFormatter",
"fmt": "%(levelprefix)s %(message)s",
"use_colors": None,
},
"access": {
"()": "uvicorn.logging.AccessFormatter",
"fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s', # noqa: E501
},
},
"handlers": {
"default": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"access": {
"formatter": "access",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
"filters": ["access_log_filter"],
},
},
"loggers": {
"uvicorn": {
"handlers": ["default"],
"level": log_level.upper(),
"propagate": False,
},
"uvicorn.error": {
"level": log_level.upper(),
"handlers": ["default"],
"propagate": False,
},
"uvicorn.access": {
"handlers": ["access"],
"level": log_level.upper(),
"propagate": False,
},
},
}
return config
...@@ -62,7 +62,6 @@ def _fused_moe_lora_kernel( ...@@ -62,7 +62,6 @@ def _fused_moe_lora_kernel(
num_experts, num_experts,
lora_ids, lora_ids,
adapter_enabled, adapter_enabled,
max_loras, # <<< PR2: rename, used for masks when grid axis-2 != max_loras
# The stride variables represent how much to increase the ptr by when # The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is # moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down # how much to increase `a_ptr` by to get the element one row down
...@@ -84,7 +83,6 @@ def _fused_moe_lora_kernel( ...@@ -84,7 +83,6 @@ def _fused_moe_lora_kernel(
num_slice_c: tl.constexpr, num_slice_c: tl.constexpr,
top_k: tl.constexpr, top_k: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
USE_B_L2_CACHE: tl.constexpr, # new, enable .ca load for B
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
...@@ -106,13 +104,10 @@ def _fused_moe_lora_kernel( ...@@ -106,13 +104,10 @@ def _fused_moe_lora_kernel(
if moe_enabled == 0: if moe_enabled == 0:
# Early exit for the no moe lora case. # Early exit for the no moe lora case.
return return
# The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel. # The grid size on axis 2 is (max_loras + 1) to handle the no-lora case
# This guard ensures we don't access sorted_token_ids / expert_ids / # (lora_id == -1), but sorted_token_ids and expert_ids are allocated with
# num_tokens_post_padded beyond their allocated bounds if an invalid # shape (max_loras, ...). Use (num_programs - 1) for correct bounds checking.
# lora_id somehow appears. Although the caller should pass correct max_loras = tl.num_programs(axis=2) - 1
# max_loras, defensive programming prevents accidental out-of-bounds.
if lora_id >= max_loras:
return
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n # calculate pid_m,pid_n
...@@ -141,11 +136,10 @@ def _fused_moe_lora_kernel( ...@@ -141,11 +136,10 @@ def _fused_moe_lora_kernel(
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty)) cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
# remove modulo wrap-around offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int32) offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_id + offs_token_id token_ind = stride_tl * lora_id + offs_token_id
offs_token = tl.load( offs_token = tl.load(
sorted_token_ids_ptr + token_ind, sorted_token_ids_ptr + token_ind,
...@@ -182,13 +176,7 @@ def _fused_moe_lora_kernel( ...@@ -182,13 +176,7 @@ def _fused_moe_lora_kernel(
# GDC wait waits for ALL programs in the prior kernel to complete # GDC wait waits for ALL programs in the prior kernel to complete
# before continuing. # before continuing.
# pre-fetch lora weight # pre-fetch lora weight
# add (offs_bn < N) mask; optional .ca for B b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
if USE_B_L2_CACHE:
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
else:
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
if USE_GDC and not IS_PRIMARY: if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait() tl.extra.cuda.gdc_wait()
a = tl.load( a = tl.load(
...@@ -288,7 +276,6 @@ def _fused_moe_lora_shrink( ...@@ -288,7 +276,6 @@ def _fused_moe_lora_shrink(
num_experts, num_experts,
lora_ids, lora_ids,
adapter_enabled, adapter_enabled,
lora_a_stacked[0].shape[0],
qcurr_hidden_states.stride(0), qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1), qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0), w1_lora_a_stacked.stride(0),
...@@ -305,7 +292,6 @@ def _fused_moe_lora_shrink( ...@@ -305,7 +292,6 @@ def _fused_moe_lora_shrink(
num_slice_c=num_slices, num_slice_c=num_slices,
top_k=1 if mul_routed_weight else top_k_num, top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False, MUL_ROUTED_WEIGHT=False,
USE_B_L2_CACHE=True, # new
IS_PRIMARY=True, IS_PRIMARY=True,
**shrink_config, **shrink_config,
) )
...@@ -391,7 +377,6 @@ def _fused_moe_lora_expand( ...@@ -391,7 +377,6 @@ def _fused_moe_lora_expand(
num_experts, num_experts,
lora_ids, lora_ids,
adapter_enabled, adapter_enabled,
lora_b_stacked[0].shape[0],
a_intermediate_cache1.stride(0), a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1), a_intermediate_cache1.stride(1),
w1_lora_b_stacked.stride(0), w1_lora_b_stacked.stride(0),
...@@ -408,7 +393,6 @@ def _fused_moe_lora_expand( ...@@ -408,7 +393,6 @@ def _fused_moe_lora_expand(
num_slice_c=num_slices, num_slice_c=num_slices,
top_k=1, top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_B_L2_CACHE=True, # new
IS_PRIMARY=False, IS_PRIMARY=False,
**expand_config, **expand_config,
) )
......
...@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE): ...@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE):
use_overlapped use_overlapped
and not ( and not (
(self.enable_eplb and backend != "allgather_reducescatter") (self.enable_eplb and backend != "allgather_reducescatter")
or self.moe_parallel_config.use_fi_all2allv_kernels or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
) )
and self._shared_experts is not None and self._shared_experts is not None
) )
......
...@@ -193,7 +193,6 @@ class RMSNorm(CustomOp): ...@@ -193,7 +193,6 @@ class RMSNorm(CustomOp):
variance = x_var.pow(2).mean(dim=-1, keepdim=True) variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon) x = x * torch.rsqrt(variance + variance_epsilon)
x = x.to(orig_dtype) x = x.to(orig_dtype)
if weight is not None: if weight is not None:
......
...@@ -380,7 +380,6 @@ class ReplicatedLinear(LinearBase): ...@@ -380,7 +380,6 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None, params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
eps: float | None = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
...@@ -392,8 +391,6 @@ class ReplicatedLinear(LinearBase): ...@@ -392,8 +391,6 @@ class ReplicatedLinear(LinearBase):
else: else:
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
self.eps = eps
super().__init__( super().__init__(
input_size, input_size,
output_size, output_size,
...@@ -643,7 +640,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -643,7 +640,6 @@ class ColumnParallelLinear(LinearBase):
if envs.VLLM_USE_NN and not self.is_quantization: if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
...@@ -720,13 +716,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -720,13 +716,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None, params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
eps: float | None = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
disable_tp: bool = False, disable_tp: bool = False,
): ):
self.eps = eps
self.output_sizes = output_sizes self.output_sizes = output_sizes
self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
...@@ -1366,7 +1360,6 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1366,7 +1360,6 @@ class QKVParallelLinear(ColumnParallelLinear):
if envs.VLLM_USE_NN and not self.is_quantization: if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
...@@ -1475,7 +1468,6 @@ class RowParallelLinear(LinearBase): ...@@ -1475,7 +1468,6 @@ class RowParallelLinear(LinearBase):
) )
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.update_param_tp_status() self.update_param_tp_status()
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod) self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
...@@ -1516,7 +1508,6 @@ class RowParallelLinear(LinearBase): ...@@ -1516,7 +1508,6 @@ class RowParallelLinear(LinearBase):
if envs.VLLM_USE_NN and not self.is_quantization: if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
......
...@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader, sharded_weight_loader,
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
...@@ -503,9 +502,6 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -503,9 +502,6 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1, dim=-1,
) )
# Check if running on Blackwell (SM100+) for kernel tuning
self.is_blackwell = current_platform.is_device_capability_family(100)
def forward_native( def forward_native(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -887,7 +883,6 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -887,7 +883,6 @@ class MambaMixer2(MambaBase, CustomOp):
state_batch_indices=state_indices_tensor_d_input, state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output, dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
is_blackwell=self.is_blackwell,
) )
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
......
...@@ -286,7 +286,6 @@ def selective_state_update( ...@@ -286,7 +286,6 @@ def selective_state_update(
out=None, out=None,
num_accepted_tokens=None, num_accepted_tokens=None,
cu_seqlens=None, cu_seqlens=None,
is_blackwell=False,
): ):
""" """
Argument: Argument:
...@@ -392,26 +391,17 @@ def selective_state_update( ...@@ -392,26 +391,17 @@ def selective_state_update(
if dst_state_batch_indices is not None if dst_state_batch_indices is not None
else (0, 0) else (0, 0)
) )
# We don't want autotune since it will overwrite the state. # We don't want autotune since it will overwrite the state
# We instead tune by hand based on dstate. # We instead tune by hand.
BLOCK_SIZE_M, num_warps = (
# Default (32, 4)
BLOCK_SIZE_M, num_warps = 4, 8 if dstate <= 16
else (
if dstate <= 16: (16, 4)
BLOCK_SIZE_M, num_warps = 32, 4 if dstate <= 32
elif dstate <= 32: else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
BLOCK_SIZE_M, num_warps = 16, 4 )
elif dstate <= 64: )
BLOCK_SIZE_M, num_warps = 8, 4
else:
# dstate > 64
if is_blackwell:
# Optimized for B200 with dstate>64
BLOCK_SIZE_M, num_warps = 32, 8
elif dstate <= 128:
BLOCK_SIZE_M, num_warps = 4, 4
tie_hdim = ( tie_hdim = (
A.stride(-1) == 0 A.stride(-1) == 0
and A.stride(-2) == 0 and A.stride(-2) == 0
......
...@@ -188,6 +188,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -188,6 +188,7 @@ class CompressedTensorsConfig(QuantizationConfig):
else: else:
return quant_method return quant_method
if isinstance(layer, Attention): if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self) return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
......
...@@ -42,6 +42,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -42,6 +42,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -51,6 +52,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( ...@@ -51,6 +52,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config, make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -64,6 +66,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( ...@@ -64,6 +66,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
...@@ -95,7 +98,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -95,7 +98,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
) )
from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils import W8a8GetCacheJSON
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -240,6 +242,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -240,6 +242,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.group_size = 32 self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN self.mxfp4_backend = NvFp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts self.experts_cls = MarlinExperts
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights( def create_weights(
self, self,
...@@ -316,7 +319,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -316,7 +319,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
) )
def process_weights_after_loading(self, layer: FusedMoE) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight = torch.nn.Parameter( layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False layer.w13_weight_packed.data, requires_grad=False
) )
...@@ -331,12 +334,10 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -331,12 +334,10 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None: if self.moe_quant_config is not None:
self.moe_mk = make_nvfp4_moe_kernel( self.kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
def apply( def apply(
...@@ -346,8 +347,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -346,8 +347,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None assert self.kernel is not None
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -378,10 +379,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -378,10 +379,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
activation_key=None if use_a16 else kNvfp4Dynamic, activation_key=None if use_a16 else kNvfp4Dynamic,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -495,7 +505,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -495,7 +505,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
) )
set_weight_attrs(w2_input_scale, extra_weight_attrs) set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: FusedMoE) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
""" """
Convert NVFP4 MoE weights into kernel format and setup the kernel. Convert NVFP4 MoE weights into kernel format and setup the kernel.
""" """
...@@ -561,33 +571,48 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -561,33 +571,48 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk = make_nvfp4_moe_kernel( self.kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError( if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
f"{self.__class__.__name__} uses the new modular kernel initialization " return None
"logic. This function should not be called." elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
) )
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError( assert self.moe_quant_config is not None
f"{self.__class__.__name__} uses the new modular kernel initialization " assert self.experts_cls is not None
"logic. This function should not be called." return make_nvfp4_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -658,8 +683,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -658,8 +683,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
) )
else: else:
assert self.moe_mk is not None assert self.kernel is not None
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -733,6 +758,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -733,6 +758,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
allow_vllm_cutlass=True, allow_vllm_cutlass=True,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -892,27 +926,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -892,27 +926,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: FusedMoE) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
pass
def process_weights_after_loading(self, layer: FusedMoE) -> None:
# Allow for accessing weights and scales in standard way. # Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight w13 = layer.w13_weight
w2 = layer.w2_weight w2 = layer.w2_weight
...@@ -974,34 +988,49 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -974,34 +988,49 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel( self.kernel, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError( if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
f"{self.__class__.__name__} uses the new modular kernel initialization " return None
"logic. This function should not be called." elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
) )
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
raise ValueError( assert self.moe_quant_config is not None
f"{self.__class__.__name__} uses the new modular kernel initialization " assert self.experts_cls is not None
"logic. This function should not be called." return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -1080,12 +1109,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1080,12 +1109,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic assert not self.is_monolithic
assert self.moe_mk is not None assert self.kernel is not None
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1134,7 +1161,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1134,7 +1161,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
"For INT8 Fused MoE layers, we require channelwise, " "For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales." "dynamic per token quantization. Found static input scales."
) )
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights( def create_weights(
self, self,
...@@ -1203,22 +1229,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1203,22 +1229,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
pass pass
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -1238,8 +1248,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1238,8 +1248,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -1255,8 +1263,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1255,8 +1263,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
use_fused_gate=use_fused_gate,
use_nn_moe=False,
) )
...@@ -1869,7 +1875,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1869,7 +1875,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
return True return True
# TODO @gaoqiong
class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
""" """
CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform
......
...@@ -16,9 +16,6 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ...@@ -16,9 +16,6 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported, cutlass_fp4_supported,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
...@@ -162,17 +159,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -162,17 +159,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
if self.backend == "fbgemm": if self.backend == "fbgemm":
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
# Pad weights for CUTLASS/FlashInfer kernel alignment (K and N
# divisible by 32). fbgemm has its own layout requirements.
if self.backend in ("cutlass", "flashinfer-cutlass"):
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight_packed.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
layer.weights_padding_cols = 0
layer.weight_packed = Parameter( layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False layer.weight_packed.data, requires_grad=False
) )
...@@ -201,8 +187,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -201,8 +187,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
return out return out
output_dtype = x.dtype output_dtype = x.dtype
output_size = layer.output_size_per_partition output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]]
output_shape = [*x.shape[:-1], output_size]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant( x_fp4, x_blockscale = scaled_fp4_quant(
...@@ -212,10 +197,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -212,10 +197,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
backend=self.backend, backend=self.backend,
) )
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = ( mm_args = (
x_fp4, x_fp4,
layer.weight_packed, layer.weight_packed,
...@@ -240,9 +221,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -240,9 +221,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
assert self.backend == "cutlass" assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args) out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.view(*output_shape) return out.view(*output_shape)
\ No newline at end of file
...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -51,6 +52,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( ...@@ -51,6 +52,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
...@@ -676,6 +678,15 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -676,6 +678,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_vllm_cutlass=False, allow_vllm_cutlass=False,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: Module, layer: Module,
...@@ -801,7 +812,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -801,7 +812,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel( def _setup_kernel(
self, self,
layer: FusedMoE, layer: Module,
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
...@@ -833,15 +844,16 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -833,15 +844,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel( self.kernel, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
...@@ -896,19 +908,33 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -896,19 +908,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError( if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
f"{self.__class__.__name__} uses the new modular kernel initialization " return None
"logic. This function should not be called." elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
) )
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
raise ValueError( assert self.moe_quant_config is not None
f"{self.__class__.__name__} uses the new modular kernel initialization " assert self.experts_cls is not None
"logic. This function should not be called." return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -948,7 +974,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -948,7 +974,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self, self,
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor,**_, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
...@@ -1002,9 +1028,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1002,9 +1028,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None assert self.kernel is not None
assert not self.is_monolithic assert not self.is_monolithic
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
......
...@@ -304,37 +304,6 @@ class XPUFp8LinearMethod(Fp8LinearMethod): ...@@ -304,37 +304,6 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
def __init__(self, quant_config: Fp8Config): def __init__(self, quant_config: Fp8Config):
super().__init__(quant_config) super().__init__(quant_config)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment