Commit 53076d70 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.2' into v0.8.2-ori

parents 322a0be6 9c5c81b0
...@@ -783,7 +783,6 @@ class LLMEngine: ...@@ -783,7 +783,6 @@ class LLMEngine:
preprocessed_inputs = self.input_preprocessor.preprocess( preprocessed_inputs = self.input_preprocessor.preprocess(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
...@@ -955,12 +954,12 @@ class LLMEngine: ...@@ -955,12 +954,12 @@ class LLMEngine:
""" """
return self.scheduler[virtual_engine].has_unfinished_seqs() return self.scheduler[virtual_engine].has_unfinished_seqs()
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices.""" """Reset prefix cache for all devices."""
success = True success = True
for scheduler in self.scheduler: for scheduler in self.scheduler:
success = success and scheduler.reset_prefix_cache() success = success and scheduler.reset_prefix_cache(device)
return success return success
@staticmethod @staticmethod
......
...@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest ...@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import deprecate_kwargs from vllm.utils import Device, deprecate_kwargs
VLLM_RPC_SUCCESS_STR = "SUCCESS" VLLM_RPC_SUCCESS_STR = "SUCCESS"
...@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum): ...@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE = 2 STOP_PROFILE = 2
class RPCResetPrefixCacheRequest(Enum): @dataclass
RESET_PREFIX_CACHE = 1 class RPCResetPrefixCacheRequest:
device: Device
class RPCSleepRequest(Enum): class RPCSleepRequest(Enum):
......
...@@ -47,7 +47,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput ...@@ -47,7 +47,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs from vllm.utils import Device, deprecate_kwargs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -684,11 +684,12 @@ class MQLLMEngineClient(EngineClient): ...@@ -684,11 +684,12 @@ class MQLLMEngineClient(EngineClient):
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
async def reset_prefix_cache(self) -> None: async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
"""Reset the prefix cache""" """Reset the prefix cache"""
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE, request=RPCResetPrefixCacheRequest(device),
socket=self.input_socket) socket=self.input_socket)
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
......
...@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams ...@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import collect_from_async_generator, random_uuid from vllm.utils import Device, collect_from_async_generator, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -81,10 +81,7 @@ class EngineClient(ABC): ...@@ -81,10 +81,7 @@ class EngineClient(ABC):
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError raise NotImplementedError
else: else:
processed_inputs = preprocessor._prompt_to_llm_inputs( processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
prompt,
request_id=request_id,
)
prompt_token_ids = processed_inputs["prompt_token_ids"] prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt") prompt_text = processed_inputs.get("prompt")
...@@ -274,7 +271,8 @@ class EngineClient(ABC): ...@@ -274,7 +271,8 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def reset_prefix_cache(self) -> None: async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
"""Reset the prefix cache""" """Reset the prefix cache"""
... ...
......
...@@ -42,7 +42,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, ...@@ -42,7 +42,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1187,8 +1188,8 @@ class LLM: ...@@ -1187,8 +1188,8 @@ class LLM:
def stop_profile(self) -> None: def stop_profile(self) -> None:
self.llm_engine.stop_profile() self.llm_engine.stop_profile()
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return self.llm_engine.reset_prefix_cache() return self.llm_engine.reset_prefix_cache(device)
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
""" """
......
...@@ -85,7 +85,7 @@ from vllm.logger import init_logger ...@@ -85,7 +85,7 @@ from vllm.logger import init_logger
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value) maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit) is_valid_ipv6_address, set_ulimit)
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
...@@ -677,8 +677,12 @@ if envs.VLLM_SERVER_DEV_MODE: ...@@ -677,8 +677,12 @@ if envs.VLLM_SERVER_DEV_MODE:
Reset the prefix cache. Note that we currently do not check if the Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server. prefix cache is successfully reset in the API server.
""" """
logger.info("Resetting prefix cache...") device = None
await engine_client(raw_request).reset_prefix_cache() device_str = raw_request.query_params.get("device")
if device_str is not None:
device = Device[device_str.upper()]
logger.info("Resetting prefix cache with specific %s...", str(device))
await engine_client(raw_request).reset_prefix_cache(device)
return Response(status_code=200) return Response(status_code=200)
@router.post("/sleep") @router.post("/sleep")
...@@ -1032,6 +1036,9 @@ async def run_server(args, **uvicorn_kwargs) -> None: ...@@ -1032,6 +1036,9 @@ async def run_server(args, **uvicorn_kwargs) -> None:
host=args.host, host=args.host,
port=args.port, port=args.port,
log_level=args.uvicorn_log_level, log_level=args.uvicorn_log_level,
# NOTE: When the 'disable_uvicorn_access_log' value is True,
# no access log will be output.
access_log=not args.disable_uvicorn_access_log,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE, timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile, ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile, ssl_certfile=args.ssl_certfile,
......
...@@ -89,6 +89,9 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ...@@ -89,6 +89,9 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default="info", default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="Log level for uvicorn.") help="Log level for uvicorn.")
parser.add_argument("--disable-uvicorn-access-log",
action="store_true",
help="Disable uvicorn access log.")
parser.add_argument("--allow-credentials", parser.add_argument("--allow-credentials",
action="store_true", action="store_true",
help="Allow credentials.") help="Allow credentials.")
...@@ -286,13 +289,6 @@ def validate_parsed_serve_args(args: argparse.Namespace): ...@@ -286,13 +289,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
raise TypeError("Error: --enable-reasoning requires " raise TypeError("Error: --enable-reasoning requires "
"--reasoning-parser") "--reasoning-parser")
# Ref https://api-docs.deepseek.com/guides/reasoning_model
# tool call and reasoning cannot be enabled at the same time.
if args.enable_auto_tool_choice and args.enable_reasoning:
raise TypeError(
"Error: --enable-auto-tool-choice and "
"--enable-reasoning cannot be enabled at the same time")
def create_parser_for_docs() -> FlexibleArgumentParser: def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser( parser_for_docs = FlexibleArgumentParser(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
from abc import abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from functools import cached_property from functools import cached_property
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
...@@ -76,6 +77,40 @@ class ReasoningParser: ...@@ -76,6 +77,40 @@ class ReasoningParser:
"AbstractReasoningParser.extract_reasoning_content_streaming " "AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!") "has not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """
    Check if the reasoning content ends in the input_ids.

    Parameters:
    input_ids: list[int]
        The input_ids of the model output.

    Returns:
    bool
        True if the reasoning content ends in the input_ids.

    Raises:
    NotImplementedError
        Always, in this base implementation.
    """
    # Adjacent string literals are joined with no separator, so the space
    # between "has" and "not" has to be written out explicitly.
    raise NotImplementedError(
        "AbstractReasoningParser.is_reasoning_end has "
        "not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """
    Return the content token ids contained in the model output.

    Parameters:
    input_ids: list[int]
        The token ids of the model output.

    Returns:
    list[int]
        The content token ids extracted from input_ids.

    Raises:
    NotImplementedError
        Always, in this base implementation.
    """
    # The base implementation only reports that no override was provided.
    message = ("AbstractReasoningParser.extract_content_ids has"
               " not been implemented!")
    raise NotImplementedError(message)
class ReasoningParserManager: class ReasoningParserManager:
reasoning_parsers: dict[str, type] = {} reasoning_parsers: dict[str, type] = {}
......
...@@ -45,6 +45,19 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -45,6 +45,19 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"DeepSeek R1 reasoning parser could not locate think start/end " "DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!") "tokens in the tokenizer!")
# TODO: need to rebase by PR #14428
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Return True when the think-end token id occurs in input_ids."""
    end_token_id = self.think_end_token_id
    return end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """
    Return the token ids that follow the first think-end token.

    An empty list is returned when the think-end token is absent, or when
    its only occurrence is the last element (nothing follows it).
    """
    try:
        end_position = input_ids.index(self.think_end_token_id)
    except ValueError:
        # No think-end token anywhere in the output.
        return []
    # If the first occurrence is the final element, this slice is empty.
    return input_ids[end_position + 1:]
def extract_reasoning_content_streaming( def extract_reasoning_content_streaming(
self, self,
previous_text: str, previous_text: str,
......
...@@ -328,6 +328,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -328,6 +328,9 @@ class OpenAIServingChat(OpenAIServing):
# These are only required in "auto" tool choice case # These are only required in "auto" tool choice case
previous_texts = [""] * num_choices previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices all_previous_token_ids = [[]] * num_choices
# For reasoning parser and tool call all enabled
added_content_delta_arr = [False] * num_choices
reasoning_end_arr = [False] * num_choices
else: else:
previous_texts, all_previous_token_ids = None, None previous_texts, all_previous_token_ids = None, None
...@@ -477,27 +480,116 @@ class OpenAIServingChat(OpenAIServing): ...@@ -477,27 +480,116 @@ class OpenAIServingChat(OpenAIServing):
delta_message: Optional[DeltaMessage] delta_message: Optional[DeltaMessage]
# handle streaming deltas for tools with named tool_choice # just update previous_texts and previous_token_ids
if tool_choice_function_name: if tool_choice_auto or should_stream_with_reasoning_parsing:
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=tool_choice_function_name,
arguments=delta_text),
index=i)
])
# handle streaming deltas for tools with "auto" tool choice
elif tool_choice_auto:
assert previous_texts is not None assert previous_texts is not None
assert all_previous_token_ids is not None assert all_previous_token_ids is not None
assert tool_parser is not None
#TODO optimize manipulation of these lists
previous_text = previous_texts[i] previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i] previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list( current_token_ids = previous_token_ids + list(
output.token_ids) output.token_ids)
# handle streaming deltas for tools with named tool_choice
if tool_choice_function_name:
if (self.enable_reasoning
and not reasoning_parser.is_reasoning_end(
previous_token_ids)):
assert reasoning_parser is not None
delta_message = (
reasoning_parser.
extract_reasoning_content_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output.token_ids,
))
# When encountering think end id in delta_token_ids,
# process the `content`. Only keep 'content',
# remove 'reasoning_content'
if reasoning_parser.is_reasoning_end(
list(output.token_ids)):
if delta_message and delta_message.content:
# This need to be added to next `delta_text`
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
else:
# Just to add remaining `content`
if self.enable_reasoning:
delta_text = previous_text + delta_text
current_text = ""
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=tool_choice_function_name,
arguments=delta_text),
index=i)
])
# handle streaming deltas for tools with "auto" tool choice
# and reasoning parser
elif tool_choice_auto and self.enable_reasoning:
assert tool_parser is not None
assert reasoning_parser is not None
assert added_content_delta_arr is not None
assert reasoning_end_arr is not None
if not reasoning_end_arr[i]:
delta_message = (
reasoning_parser.
extract_reasoning_content_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output.token_ids,
))
# When encountering think end id in delta_token_ids,
# set reasoning status to end.
# Remove the text and token ids related
# to 'reasoning_content'.
if reasoning_parser.is_reasoning_end(
list(output.token_ids)):
reasoning_end_arr[i] = True
current_token_ids = \
reasoning_parser.extract_content_ids(
list(output.token_ids))
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
# handle tool calls only after reasoning is done,
else:
delta_token_ids = list(output.token_ids)
# First time to tool call,
# add the remaining text and token ids
# to delta from previous
if not added_content_delta_arr[i]:
added_content_delta_arr[i] = True
previous_text = ""
previous_token_ids = []
delta_text = current_text
delta_token_ids = current_token_ids
delta_message = (
tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request))
# when only tool calls
elif tool_choice_auto:
assert tool_parser is not None
delta_message = ( delta_message = (
tool_parser.extract_tool_calls_streaming( tool_parser.extract_tool_calls_streaming(
previous_text=previous_text, previous_text=previous_text,
...@@ -507,23 +599,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -507,23 +599,9 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids=current_token_ids, current_token_ids=current_token_ids,
delta_token_ids=output.token_ids, delta_token_ids=output.token_ids,
request=request)) request=request))
# when only reasoning
# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# reasoning_content cannot be enabled with tool_choice.
# If it is, the tool_choice will be used instead.
elif self.enable_reasoning: elif self.enable_reasoning:
# handle reasoning_content delta
assert reasoning_parser is not None assert reasoning_parser is not None
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)
delta_message = (reasoning_parser. delta_message = (reasoning_parser.
extract_reasoning_content_streaming( extract_reasoning_content_streaming(
previous_text, previous_text,
...@@ -533,15 +611,17 @@ class OpenAIServingChat(OpenAIServing): ...@@ -533,15 +611,17 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids, current_token_ids,
output.token_ids, output.token_ids,
)) ))
# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# handle streaming just a content delta # handle streaming just a content delta
else: else:
delta_message = DeltaMessage(content=delta_text) delta_message = DeltaMessage(content=delta_text)
# update the previous values for the next iteration
if tool_choice_auto or should_stream_with_reasoning_parsing:
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# set the previous values for the next iteration # set the previous values for the next iteration
previous_num_tokens[i] += len(output.token_ids) previous_num_tokens[i] += len(output.token_ids)
...@@ -739,24 +819,24 @@ class OpenAIServingChat(OpenAIServing): ...@@ -739,24 +819,24 @@ class OpenAIServingChat(OpenAIServing):
except RuntimeError as e: except RuntimeError as e:
logger.exception("Error in reasoning parser creation.") logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
# If the reasoning parser is enabled,
# tool calls are extracted exclusively from the content.
reasoning_content, content = ( reasoning_content, content = (
reasoning_parser.extract_reasoning_content( reasoning_parser.extract_reasoning_content(
output.text, request=request)) output.text, request=request))
else:
if reasoning_content: reasoning_content = None
message = ChatMessage(role=role, content = output.text
content=content,
reasoning_content=reasoning_content)
else:
message = ChatMessage(role=role, content=output.text)
# if auto tools are not enabled, and a named tool choice using # if auto tools are not enabled, and a named tool choice using
# outlines is not being used # outlines is not being used
elif (not self.enable_auto_tools if (not self.enable_auto_tools
or not self.tool_parser) and not isinstance( or not self.tool_parser) and not isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam): request.tool_choice,
message = ChatMessage(role=role, content=output.text) ChatCompletionNamedToolChoiceParam):
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
# if the request uses tools and specified a tool choice # if the request uses tools and specified a tool choice
elif request.tool_choice and type( elif request.tool_choice and type(
...@@ -766,18 +846,21 @@ class OpenAIServingChat(OpenAIServing): ...@@ -766,18 +846,21 @@ class OpenAIServingChat(OpenAIServing):
tokenizer, MistralTokenizer) else ToolCall tokenizer, MistralTokenizer) else ToolCall
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning_content=reasoning_content,
content="", content="",
tool_calls=[ tool_calls=[
tool_call_class(function=FunctionCall( tool_call_class(function=FunctionCall(
name=request.tool_choice.function.name, name=request.tool_choice.function.name,
arguments=output.text)) arguments=content))
]) ])
# if the request doesn't use tool choice # if the request doesn't use tool choice
# OR specifies to not use a tool # OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none": elif not request.tool_choice or request.tool_choice == "none":
message = ChatMessage(role=role, content=output.text) message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
# handle when there are tools and tool choice is auto # handle when there are tools and tool choice is auto
elif request.tools and ( elif request.tools and (
...@@ -792,20 +875,23 @@ class OpenAIServingChat(OpenAIServing): ...@@ -792,20 +875,23 @@ class OpenAIServingChat(OpenAIServing):
return self.create_error_response(str(e)) return self.create_error_response(str(e))
tool_call_info = tool_parser.extract_tool_calls( tool_call_info = tool_parser.extract_tool_calls(
output.text, request=request) content if content is not None else "", request=request)
# In the OpenAI API the finish_reason is "tools_called" # In the OpenAI API the finish_reason is "tools_called"
# if the tool choice is auto and the model produced a tool # if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls # call. The same is not true for named function calls
auto_tools_called = tool_call_info.tools_called auto_tools_called = tool_call_info.tools_called
if tool_call_info.tools_called: if tool_call_info.tools_called:
message = ChatMessage(role=role, message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=tool_call_info.content, content=tool_call_info.content,
tool_calls=tool_call_info.tool_calls) tool_calls=tool_call_info.tool_calls)
else: else:
# FOR NOW make it a chat message; we will have to detect # FOR NOW make it a chat message; we will have to detect
# the type to make it later. # the type to make it later.
message = ChatMessage(role=role, content=output.text) message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
# undetermined case that is still important to handle # undetermined case that is still important to handle
else: else:
...@@ -813,7 +899,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -813,7 +899,9 @@ class OpenAIServingChat(OpenAIServing):
"Error in chat_completion_full_generator - cannot determine" "Error in chat_completion_full_generator - cannot determine"
" if tools should be extracted. Returning a standard chat " " if tools should be extracted. Returning a standard chat "
"completion.") "completion.")
message = ChatMessage(role=role, content=output.text) message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
choice_data = ChatCompletionResponseChoice( choice_data = ChatCompletionResponseChoice(
index=output.index, index=output.index,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import hashlib
import os import os
import tempfile import tempfile
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Callable, Optional
...@@ -40,11 +41,8 @@ if TYPE_CHECKING: ...@@ -40,11 +41,8 @@ if TYPE_CHECKING:
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_MOE_PREPACK: bool = True VLLM_CPU_MOE_PREPACK: bool = True
VLLM_OPENVINO_DEVICE: str = "CPU"
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False
...@@ -74,10 +72,13 @@ if TYPE_CHECKING: ...@@ -74,10 +72,13 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = [] VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False VLLM_DISABLE_COMPILE_CACHE: bool = False
Q_SCALE_CONSTANT: int = 200
K_SCALE_CONSTANT: int = 200 K_SCALE_CONSTANT: int = 200
V_SCALE_CONSTANT: int = 100 V_SCALE_CONSTANT: int = 100
VLLM_SERVER_DEV_MODE: bool = False VLLM_SERVER_DEV_MODE: bool = False
...@@ -94,6 +95,7 @@ if TYPE_CHECKING: ...@@ -94,6 +95,7 @@ if TYPE_CHECKING:
VLLM_DP_MASTER_PORT: int = 0 VLLM_DP_MASTER_PORT: int = 0
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -126,7 +128,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -126,7 +128,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ================== # ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default), # Target device of vLLM, supporting [cuda (by default),
# rocm, neuron, cpu, openvino] # rocm, neuron, cpu]
"VLLM_TARGET_DEVICE": "VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"),
...@@ -353,28 +355,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -353,28 +355,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CPU_MOE_PREPACK": "VLLM_CPU_MOE_PREPACK":
lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))),
# OpenVINO device selection
# default is CPU
"VLLM_OPENVINO_DEVICE":
lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(),
# OpenVINO key-value cache space
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0")),
# OpenVINO KV cache precision
# default is bf16 if natively supported by platform, otherwise f16
# To enable KV cache compression, please, explicitly specify u8
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION":
lambda: os.getenv("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", None),
# Enables weights compression during model export via HF Optimum
# default is False
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS":
lambda:
(os.environ.get("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", "0").lower() in
("on", "true", "1")),
# If the env var is set, then all workers will execute as separate # If the env var is set, then all workers will execute as separate
# processes from the engine, and we use the same mechanism to trigger # processes from the engine, and we use the same mechanism to trigger
# execution on all workers. # execution on all workers.
...@@ -444,6 +424,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -444,6 +424,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_XLA_CACHE_PATH", "VLLM_XLA_CACHE_PATH",
os.path.join(get_default_cache_root(), "vllm", "xla_cache"), os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
)), )),
# If set, assert on XLA recompilation after each execution step.
"VLLM_XLA_CHECK_RECOMPILATION":
lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
"VLLM_FUSED_MOE_CHUNK_SIZE": "VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),
...@@ -521,16 +505,31 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -521,16 +505,31 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V1": "VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),
# Disable aiter ops unless specifically enabled.
# Acts as a parent switch to enable the rest of the other operations.
"VLLM_ROCM_USE_AITER":
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),
# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
("true", "1")),
# Pad the fp8 weights to 256 bytes for ROCm # Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING": "VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
# Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT":
lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
# Divisor for dynamic key scale factor calculation for FP8 KV Cache # Divisor for dynamic key scale factor calculation for FP8 KV Cache
"K_SCALE_CONSTANT": "K_SCALE_CONSTANT":
lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
# Divisor for dynamic value scale factor calculation for FP8 KV Cache # Divisor for dynamic value scale factor calculation for FP8 KV Cache
"V_SCALE_CONSTANT": "V_SCALE_CONSTANT":
lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
# If set, enable multiprocessing in LLM for the V1 code path. # If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING": "VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
...@@ -618,6 +617,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -618,6 +617,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# an environment with potentially malicious users. # an environment with potentially malicious users.
"VLLM_V0_USE_OUTLINES_CACHE": "VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
# If set, disables TPU-specific optimization for top-k & top-p sampling
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
} }
# end-env-vars-definition # end-env-vars-definition
...@@ -648,3 +652,43 @@ def set_vllm_use_v1(use_v1: bool): ...@@ -648,3 +652,43 @@ def set_vllm_use_v1(use_v1: bool):
"explicitly by the user. Please raise this as a Github " "explicitly by the user. Please raise this as a Github "
"Issue and explicitly set VLLM_USE_V1=0 or 1.") "Issue and explicitly set VLLM_USE_V1=0 or 1.")
os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0" os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0"
def compute_hash() -> str:
    """Return an MD5 hex digest summarizing selected environment variables.

    WARNING: Whenever a new key is added to this environment
    variables, ensure that it is included in the factors list if
    it affects the computation graph. For example, different values
    of VLLM_PP_LAYER_PARTITION will generate different computation
    graphs, so it is included in the factors list. The env vars that
    affect the choice of different kernels or attention backends should
    also be included in the factors list.

    Returns:
        The hexadecimal MD5 digest of ``str(factors)``, where ``factors``
        holds one entry per hashed variable, in the order listed below.
    """
    factors: list[Any] = []

    def factorize(name: str) -> None:
        """Append the current value of ``name`` to ``factors``."""
        # Resolve the variable a single time so that the truthiness test and
        # the recorded value are guaranteed to come from the same read.
        value = __getattr__(name)
        # Any falsy value (None, False, 0, "") is recorded as the string
        # "None". This mapping is kept on purpose so that digests stay
        # identical to the ones produced before.
        factors.append(value if value else "None")

    # The values of envs may affect the computation graph.
    # TODO(DefTruth): hash all environment variables, i.e. call
    # factorize(key) for every key in environment_variables?
    environment_variables_to_hash = [
        "VLLM_PP_LAYER_PARTITION",
        "VLLM_MLA_DISABLE",
        "VLLM_USE_TRITON_FLASH_ATTN",
        "VLLM_USE_TRITON_AWQ",
        "VLLM_DP_RANK",
        "VLLM_DP_SIZE",
    ]
    for key in environment_variables_to_hash:
        # Names that are not registered in environment_variables are skipped.
        if key in environment_variables:
            factorize(key)

    # The digest is only a fingerprint of the factors, not a security
    # primitive. Declaring that keeps the call usable on Python builds that
    # restrict MD5 for security purposes; the digest value is unchanged.
    return hashlib.md5(str(factors).encode(),
                       usedforsecurity=False).hexdigest()
...@@ -16,7 +16,7 @@ import torch ...@@ -16,7 +16,7 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import _check_multiproc_method, get_mp_context, run_method from vllm.utils import _maybe_force_spawn, get_mp_context, run_method
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config): ...@@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config):
in a multiprocessing environment. This should be called by the parent in a multiprocessing environment. This should be called by the parent
process before worker processes are created""" process before worker processes are created"""
_check_multiproc_method() _maybe_force_spawn()
# Configure thread parallelism if OMP_NUM_THREADS isn't set # Configure thread parallelism if OMP_NUM_THREADS isn't set
# #
......
...@@ -340,6 +340,8 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -340,6 +340,8 @@ class RayDistributedExecutor(DistributedExecutorBase):
and v not in self.non_carry_over_env_vars and v not in self.non_carry_over_env_vars
] ]
env_vars_to_copy.extend(current_platform.additional_env_vars)
# Copy existing env vars to each worker's args # Copy existing env vars to each worker's args
for args in all_args_to_update_environment_variables: for args in all_args_to_update_environment_variables:
# TODO: refactor platform-specific env vars # TODO: refactor platform-specific env vars
...@@ -559,6 +561,15 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -559,6 +561,15 @@ class RayDistributedExecutor(DistributedExecutorBase):
envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL) envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution,
# i.e., the distributed execution that includes model forward runs and
# intermediate tensor communications, in the case of vllm.
os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112
logger.info("RAY_CGRAPH_get_timeout is set to %s",
os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112
with InputNode() as input_data: with InputNode() as input_data:
# Example DAG: PP=2, TP=4 # Example DAG: PP=2, TP=4
# #
......
...@@ -17,7 +17,7 @@ from vllm.utils import get_ip ...@@ -17,7 +17,7 @@ from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -284,8 +284,9 @@ def initialize_ray_cluster( ...@@ -284,8 +284,9 @@ def initialize_ray_cluster(
assert_ray_available() assert_ray_available()
from vllm.platforms import current_platform from vllm.platforms import current_platform
# Connect to a ray cluster. if ray.is_initialized():
if current_platform.is_rocm() or current_platform.is_xpu(): logger.info("Ray is already initialized. Skipping Ray initialization.")
elif current_platform.is_rocm() or current_platform.is_xpu():
# Try to connect existing ray instance and create a new one if not found # Try to connect existing ray instance and create a new one if not found
try: try:
ray.init("auto", ignore_reinit_error=True) ray.init("auto", ignore_reinit_error=True)
...@@ -299,19 +300,21 @@ def initialize_ray_cluster( ...@@ -299,19 +300,21 @@ def initialize_ray_cluster(
else: else:
ray.init(address=ray_address, ignore_reinit_error=True) ray.init(address=ray_address, ignore_reinit_error=True)
if parallel_config.placement_group:
# Placement group is already set.
return
device_str = current_platform.ray_device_key device_str = current_platform.ray_device_key
if not device_str: if not device_str:
raise ValueError( raise ValueError(
f"current platform {current_platform.device_name} does not " f"current platform {current_platform.device_name} does not "
"support ray.") "support ray.")
# Create placement group for worker processes # Create or get the placement group for worker processes
current_placement_group = ray.util.get_current_placement_group() if parallel_config.placement_group:
current_placement_group = parallel_config.placement_group
else:
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group: if current_placement_group:
logger.info("Using the existing placement group")
# We are in a placement group # We are in a placement group
bundles = current_placement_group.bundle_specs bundles = current_placement_group.bundle_specs
# Verify that we can use the placement group. # Verify that we can use the placement group.
...@@ -331,6 +334,8 @@ def initialize_ray_cluster( ...@@ -331,6 +334,8 @@ def initialize_ray_cluster(
f"Required number of devices: {parallel_config.world_size}. " f"Required number of devices: {parallel_config.world_size}. "
f"Total number of devices: {device_bundles}.") f"Total number of devices: {device_bundles}.")
else: else:
logger.info("No current placement group found. "
"Creating a new placement group.")
num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
# Log a warning message and delay resource allocation failure response. # Log a warning message and delay resource allocation failure response.
# Avoid immediate rejection to allow user-initiated placement group # Avoid immediate rejection to allow user-initiated placement group
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from vllm import envs
from vllm.logger import init_logger
logger = init_logger(__name__)
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
    """Pick the FlashAttention (FA) version to use on the current platform.

    The choice is made in three steps: a platform default, an optional
    override from ``envs.VLLM_FLASH_ATTN_VERSION``, and a fallback from
    version 3 to version 2 for combinations that cannot use version 3.

    Args:
        requires_alibi: Whether the caller needs ALiBi; FA version 3 is
            downgraded to version 2 when this is set.

    Returns:
        The selected FA version (2 or 3), or ``None`` when the flash-attn
        interface cannot be imported, the device capability is unavailable,
        the environment override is not 2 or 3, or the selected version is
        reported as unsupported.
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason, is_fa_version_supported)

        # Explicit checks are used instead of ``assert`` so that the
        # "return None" outcomes still happen when Python runs with -O,
        # which strips assert statements.
        device_capability = current_platform.get_device_capability()
        if device_capability is None:
            return None

        # 1. default version depending on platform
        fa_version = 3 if (device_capability.major == 9
                           and is_fa_version_supported(3)) else 2

        # 2. override if passed by environment
        requested_version = envs.VLLM_FLASH_ATTN_VERSION
        if requested_version is not None:
            if requested_version not in (2, 3):
                return None
            fa_version = requested_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform "
                "defaulting to FA version 2.")
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once("Cannot use FA version 3 with ALiBi, "
                                "defaulting to FA version 2.")
            fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error("Cannot use FA version %d is not supported due to %s",
                         fa_version, fa_version_unsupported_reason(fa_version))
            return None

        return fa_version
    except (ImportError, AssertionError):
        # AssertionError is still handled so that assertions raised inside
        # the called helpers keep mapping to "no usable FA version".
        return None
...@@ -182,7 +182,6 @@ class InputPreprocessor: ...@@ -182,7 +182,6 @@ class InputPreprocessor:
def _tokenize_prompt( def _tokenize_prompt(
self, self,
prompt: str, prompt: str,
request_id: str,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
) -> list[int]: ) -> list[int]:
""" """
...@@ -202,15 +201,13 @@ class InputPreprocessor: ...@@ -202,15 +201,13 @@ class InputPreprocessor:
"do_lower_case", False)): "do_lower_case", False)):
prompt = prompt.lower() prompt = prompt.lower()
return tokenizer.encode(request_id=request_id, return tokenizer.encode(prompt=prompt,
prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens)
async def _tokenize_prompt_async( async def _tokenize_prompt_async(
self, self,
prompt: str, prompt: str,
request_id: str,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
) -> list[int]: ) -> list[int]:
"""Async version of :meth:`_tokenize_prompt`.""" """Async version of :meth:`_tokenize_prompt`."""
...@@ -222,7 +219,6 @@ class InputPreprocessor: ...@@ -222,7 +219,6 @@ class InputPreprocessor:
# appending an EOS token to the prompt which disrupts generation. # appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False add_special_tokens = False
return await tokenizer.encode_async( return await tokenizer.encode_async(
request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens)
...@@ -309,7 +305,6 @@ class InputPreprocessor: ...@@ -309,7 +305,6 @@ class InputPreprocessor:
def _prompt_to_llm_inputs( def _prompt_to_llm_inputs(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> SingletonInputs: ) -> SingletonInputs:
...@@ -318,7 +313,6 @@ class InputPreprocessor: ...@@ -318,7 +313,6 @@ class InputPreprocessor:
Arguments: Arguments:
* request_id
* prompt: single encoder or decoder input prompt * prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts * lora_request: this is only valid for decoder prompts
* return_mm_hashes: whether to return multimodal hashes * return_mm_hashes: whether to return multimodal hashes
...@@ -333,7 +327,6 @@ class InputPreprocessor: ...@@ -333,7 +327,6 @@ class InputPreprocessor:
prompt_text = parsed["content"] prompt_text = parsed["content"]
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt_text, prompt_text,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -384,7 +377,6 @@ class InputPreprocessor: ...@@ -384,7 +377,6 @@ class InputPreprocessor:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt_text, prompt_text,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -400,7 +392,6 @@ class InputPreprocessor: ...@@ -400,7 +392,6 @@ class InputPreprocessor:
async def _prompt_to_llm_inputs_async( async def _prompt_to_llm_inputs_async(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> SingletonInputs: ) -> SingletonInputs:
...@@ -411,7 +402,6 @@ class InputPreprocessor: ...@@ -411,7 +402,6 @@ class InputPreprocessor:
prompt_text = parsed["content"] prompt_text = parsed["content"]
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
prompt_text, prompt_text,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -460,7 +450,6 @@ class InputPreprocessor: ...@@ -460,7 +450,6 @@ class InputPreprocessor:
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
prompt_text, prompt_text,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -560,7 +549,6 @@ class InputPreprocessor: ...@@ -560,7 +549,6 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt( def _process_encoder_decoder_prompt(
self, self,
prompt: PromptType, prompt: PromptType,
request_id: str,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
For encoder/decoder models only: For encoder/decoder models only:
...@@ -587,7 +575,6 @@ class InputPreprocessor: ...@@ -587,7 +575,6 @@ class InputPreprocessor:
Arguments: Arguments:
* prompt: an input prompt * prompt: an input prompt
* request_id
Returns: Returns:
...@@ -598,16 +585,11 @@ class InputPreprocessor: ...@@ -598,16 +585,11 @@ class InputPreprocessor:
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
encoder_inputs = self._prompt_to_llm_inputs( encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"], prompt["encoder_prompt"])
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None decoder_inputs = None
else: else:
decoder_inputs = self._prompt_to_llm_inputs( decoder_inputs = self._prompt_to_llm_inputs(decoder_input)
decoder_input,
request_id=request_id,
)
# For multimodal model, override decoder prompt from processor # For multimodal model, override decoder prompt from processor
# with explicit decoder prompt. # with explicit decoder prompt.
if self.model_config.is_multimodal_model and ( if self.model_config.is_multimodal_model and (
...@@ -616,10 +598,7 @@ class InputPreprocessor: ...@@ -616,10 +598,7 @@ class InputPreprocessor:
self._separate_enc_dec_inputs_from_mm_processor_outputs( self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs)) encoder_inputs, decoder_inputs))
else: else:
inputs = self._prompt_to_llm_inputs( inputs = self._prompt_to_llm_inputs(prompt)
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model and ( if self.model_config.is_multimodal_model and (
self._can_process_multimodal()): self._can_process_multimodal()):
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -636,7 +615,6 @@ class InputPreprocessor: ...@@ -636,7 +615,6 @@ class InputPreprocessor:
async def _process_encoder_decoder_prompt_async( async def _process_encoder_decoder_prompt_async(
self, self,
prompt: PromptType, prompt: PromptType,
request_id: str,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`.""" """Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_inputs: SingletonInputs encoder_inputs: SingletonInputs
...@@ -644,18 +622,13 @@ class InputPreprocessor: ...@@ -644,18 +622,13 @@ class InputPreprocessor:
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
encoder_task = self._prompt_to_llm_inputs_async( encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"], prompt["encoder_prompt"])
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
encoder_inputs = await encoder_task encoder_inputs = await encoder_task
decoder_inputs = None decoder_inputs = None
else: else:
decoder_task = self._prompt_to_llm_inputs_async( decoder_task = self._prompt_to_llm_inputs_async(decoder_input)
decoder_input,
request_id=request_id,
)
encoder_inputs, decoder_inputs = await asyncio.gather( encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task) encoder_task, decoder_task)
...@@ -668,10 +641,7 @@ class InputPreprocessor: ...@@ -668,10 +641,7 @@ class InputPreprocessor:
self._separate_enc_dec_inputs_from_mm_processor_outputs( self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs)) encoder_inputs, decoder_inputs))
else: else:
inputs = await self._prompt_to_llm_inputs_async( inputs = await self._prompt_to_llm_inputs_async(prompt)
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model and ( if self.model_config.is_multimodal_model and (
self._can_process_multimodal()): self._can_process_multimodal()):
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -704,7 +674,6 @@ class InputPreprocessor: ...@@ -704,7 +674,6 @@ class InputPreprocessor:
def _process_decoder_only_prompt( def _process_decoder_only_prompt(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
...@@ -716,7 +685,6 @@ class InputPreprocessor: ...@@ -716,7 +685,6 @@ class InputPreprocessor:
Arguments: Arguments:
* prompt: input prompt * prompt: input prompt
* request_id
* lora_request * lora_request
* prompt_adapter_request * prompt_adapter_request
* return_mm_hashes * return_mm_hashes
...@@ -728,7 +696,6 @@ class InputPreprocessor: ...@@ -728,7 +696,6 @@ class InputPreprocessor:
prompt_comps = self._prompt_to_llm_inputs( prompt_comps = self._prompt_to_llm_inputs(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
...@@ -741,7 +708,6 @@ class InputPreprocessor: ...@@ -741,7 +708,6 @@ class InputPreprocessor:
async def _process_decoder_only_prompt_async( async def _process_decoder_only_prompt_async(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
...@@ -749,7 +715,6 @@ class InputPreprocessor: ...@@ -749,7 +715,6 @@ class InputPreprocessor:
"""Async version of :meth:`_process_decoder_only_prompt`.""" """Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._prompt_to_llm_inputs_async( prompt_comps = await self._prompt_to_llm_inputs_async(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
...@@ -762,7 +727,6 @@ class InputPreprocessor: ...@@ -762,7 +727,6 @@ class InputPreprocessor:
def preprocess( def preprocess(
self, self,
prompt: PromptType, prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
...@@ -774,10 +738,7 @@ class InputPreprocessor: ...@@ -774,10 +738,7 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1.") "returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder # input prompts to encoder & decoder
return self._process_encoder_decoder_prompt( return self._process_encoder_decoder_prompt(prompt)
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt " raise ValueError("Cannot pass encoder-decoder prompt "
...@@ -786,7 +747,6 @@ class InputPreprocessor: ...@@ -786,7 +747,6 @@ class InputPreprocessor:
# Decoder-only operation # Decoder-only operation
return self._process_decoder_only_prompt( return self._process_decoder_only_prompt(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
...@@ -795,7 +755,6 @@ class InputPreprocessor: ...@@ -795,7 +755,6 @@ class InputPreprocessor:
async def preprocess_async( async def preprocess_async(
self, self,
prompt: PromptType, prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
...@@ -807,10 +766,7 @@ class InputPreprocessor: ...@@ -807,10 +766,7 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1.") "returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder # input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async( return await self._process_encoder_decoder_prompt_async(prompt)
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt " raise ValueError("Cannot pass encoder-decoder prompt "
...@@ -819,7 +775,6 @@ class InputPreprocessor: ...@@ -819,7 +775,6 @@ class InputPreprocessor:
# Decoder-only operation # Decoder-only operation
return await self._process_decoder_only_prompt_async( return await self._process_decoder_only_prompt_async(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
......
...@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final ...@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
import torch import torch
import vllm.envs as envs
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
...@@ -42,8 +43,15 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -42,8 +43,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras, self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras,
max_num_batched_tokens, max_num_batched_tokens,
device=device) device=device)
# When cudagraph capture size is greater than max_num_seqs (max_batches,
# here), V0 captures the graph as if max_num_seqs is set to
# the capture size.
# V1 doesn't have this problem and always respects max_num_seqs.
max_num_prompts = (max_batches
if envs.VLLM_USE_V1 else max_num_batched_tokens)
self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras, self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras,
max_batches, max_num_prompts,
device=device) device=device)
def update_metadata( def update_metadata(
......
...@@ -79,6 +79,12 @@ def maybe_backend_fallback( ...@@ -79,6 +79,12 @@ def maybe_backend_fallback(
"xgrammar does not support Lark grammars and the " "xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines") "grammar failed to convert to GBNF.", "outlines")
elif guided_params.json_object:
# https://github.com/mlc-ai/xgrammar/issues/256
fallback_or_error(guided_params,
"xgrammar does not support json_object.",
"guidance")
# If the xgrammar module cannot be imported successfully, # If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback. # we should still allow users to use guided decoding with a fallback.
elif not xgr_installed: elif not xgr_installed:
...@@ -88,9 +94,9 @@ def maybe_backend_fallback( ...@@ -88,9 +94,9 @@ def maybe_backend_fallback(
if (guided_params.backend_name == "outlines" if (guided_params.backend_name == "outlines"
and guided_params.json_object is not None): and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to xgrammar # outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params, fallback_or_error(guided_params,
"outlines does not support json_object.", "xgrammar") "outlines does not support json_object.", "guidance")
return guided_params return guided_params
...@@ -122,10 +128,15 @@ async def get_guided_decoding_logits_processor( ...@@ -122,10 +128,15 @@ async def get_guided_decoding_logits_processor(
get_local_xgrammar_guided_decoding_logits_processor) get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor( return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner) guided_params, tokenizer, model_config, reasoner)
if guided_params.backend_name == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
guided_params, tokenizer)
raise ValueError( raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. " f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
)
def get_local_guided_decoding_logits_processor( def get_local_guided_decoding_logits_processor(
...@@ -155,7 +166,13 @@ def get_local_guided_decoding_logits_processor( ...@@ -155,7 +166,13 @@ def get_local_guided_decoding_logits_processor(
get_local_xgrammar_guided_decoding_logits_processor) get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor( return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner) guided_params, tokenizer, model_config, reasoner)
if guided_params.backend_name == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
guided_params, tokenizer)
raise ValueError( raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. " f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
)
# SPDX-License-Identifier: Apache-2.0
from re import escape as regex_escape
import llguidance
from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.guidance_logits_processors import (
GuidanceLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
def get_local_guidance_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor:
    """Build a guidance logits processor from guided decoding parameters.

    The first populated field of ``guided_params`` is used, checked in this
    order: ``json``, ``json_object``, ``regex``, ``choice``, ``grammar``.

    Args:
        guided_params: Guided decoding parameters describing the guide.
        tokenizer: Tokenizer handed to the created logits processor.

    Returns:
        A ``GuidanceLogitsProcessor`` built from the derived grammar.

    Raises:
        ValueError: If no field yields a non-empty grammar.
    """
    grammar_spec = ""

    if guided_params.json or guided_params.json_object:
        # An explicit schema wins; a bare json_object request is served
        # with a schema that accepts any JSON object.
        schema = guided_params.json or '{"type": "object"}'
        grammar_spec = llguidance.LLMatcher.grammar_from_json_schema(
            schema,
            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
    elif guided_params.regex:
        grammar_spec = llguidance.grammar_from("regex", guided_params.regex)
    elif guided_params.choice:
        # A choice is expressed as a regex alternation of escaped literals.
        alternatives = "|".join(
            [regex_escape(str(option)) for option in guided_params.choice])
        grammar_spec = llguidance.grammar_from("regex", f"({alternatives})")
    elif guided_params.grammar:
        # this supports Lark and GBNF
        grammar_spec = llguidance.grammar_from("grammar",
                                               guided_params.grammar)

    if not grammar_spec:
        raise ValueError("Unknown guided decoding mode")

    return GuidanceLogitsProcessor(grammar_spec, tokenizer)
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any, List
import llguidance
import llguidance.hf
import llguidance.torch
import torch
from transformers import PreTrainedTokenizerBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class GuidanceLogitsProcessor:
    """Logits processor that masks scores according to an llguidance grammar."""

    # Shared by every instance: llguidance tokenizers keyed by the Hugging
    # Face tokenizer's ``name_or_path``.
    cached_tokenizers: dict[str, Any] = {}

    def __init__(
        self,
        grammar: str,
        tokenizer: PreTrainedTokenizerBase,
    ) -> None:
        """Store the grammar and tokenizer; matcher setup is deferred.

        Args:
            grammar (str)
                grammar to guide the generation
            tokenizer (PreTrainedTokenizerBase)
                model's tokenizer
        """
        self.grammar = grammar
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer.name_or_path
        # Becomes True after the first call; from then on the last input
        # token is fed to the matcher before the next mask is computed.
        self.new_sampling = False
        self.initialized = False

    def _cached_ll_tokenizer(self):
        """Return the llguidance tokenizer, creating and caching it if absent."""
        cache_key = self.tokenizer.name_or_path
        wrapped = self.cached_tokenizers.get(cache_key, None)
        if wrapped is None:
            wrapped = llguidance.hf.from_tokenizer(self.tokenizer, None)
            self.cached_tokenizers[cache_key] = wrapped
        return wrapped

    def _initialize(self):
        """Create the tokenizer wrapper, matcher and bitmask exactly once."""
        if self.initialized:
            return

        self.ll_tokenizer = self._cached_ll_tokenizer()

        matcher_log_level = int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1"))
        self.ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            self.grammar,
            log_level=matcher_log_level,
        )

        # create reusable bitmask
        self.bitmask = llguidance.torch.allocate_token_bitmask(
            1, self.ll_tokenizer.vocab_size)

        self.initialized = True

    def __call__(
        self,
        input_ids: List[int],
        scores: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the grammar's next-token mask to ``scores`` in place.

        Args:
            input_ids: Token ids seen so far; only the last one is consumed.
            scores: Scores for the next token, modified in place.

        Returns:
            The same ``scores`` tensor after masking.
        """
        # we initialize the guidance model here
        # to avoid pickling ll_tokenizer and ll_interpreter
        self._initialize()

        if self.new_sampling and len(input_ids) > 0:
            latest_token = input_ids[-1]
            self.ll_matcher.consume_token(latest_token)
            matcher_error = self.ll_matcher.get_error()
            if matcher_error:
                logger.warning("Error in LLMatcher: %s", matcher_error)

        llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
                                                 0)
        device_mask = self.bitmask.to(scores.device)
        llguidance.torch.apply_token_bitmask_inplace(scores, device_mask)

        self.new_sampling = True

        return scores
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment