Commit 53076d70 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.2' into v0.8.2-ori

parents 322a0be6 9c5c81b0
......@@ -783,7 +783,6 @@ class LLMEngine:
preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
......@@ -955,12 +954,12 @@ class LLMEngine:
"""
return self.scheduler[virtual_engine].has_unfinished_seqs()
def reset_prefix_cache(self) -> bool:
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices."""
success = True
for scheduler in self.scheduler:
success = success and scheduler.reset_prefix_cache()
success = success and scheduler.reset_prefix_cache(device)
return success
@staticmethod
......
......@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.utils import deprecate_kwargs
from vllm.utils import Device, deprecate_kwargs
VLLM_RPC_SUCCESS_STR = "SUCCESS"
......@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE = 2
class RPCResetPrefixCacheRequest(Enum):
RESET_PREFIX_CACHE = 1
@dataclass
class RPCResetPrefixCacheRequest:
device: Device
class RPCSleepRequest(Enum):
......
......@@ -47,7 +47,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
from vllm.utils import Device, deprecate_kwargs
logger = init_logger(__name__)
......@@ -684,11 +684,12 @@ class MQLLMEngineClient(EngineClient):
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
async def reset_prefix_cache(self) -> None:
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
"""Reset the prefix cache"""
await self._send_one_way_rpc_request(
request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
request=RPCResetPrefixCacheRequest(device),
socket=self.input_socket)
async def sleep(self, level: int = 1) -> None:
......
......@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import collect_from_async_generator, random_uuid
from vllm.utils import Device, collect_from_async_generator, random_uuid
logger = init_logger(__name__)
......@@ -81,10 +81,7 @@ class EngineClient(ABC):
if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
else:
processed_inputs = preprocessor._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt")
......@@ -274,7 +271,8 @@ class EngineClient(ABC):
...
@abstractmethod
async def reset_prefix_cache(self) -> None:
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
"""Reset the prefix cache"""
...
......
......@@ -42,7 +42,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of)
logger = init_logger(__name__)
......@@ -1187,8 +1188,8 @@ class LLM:
def stop_profile(self) -> None:
self.llm_engine.stop_profile()
def reset_prefix_cache(self) -> bool:
return self.llm_engine.reset_prefix_cache()
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return self.llm_engine.reset_prefix_cache(device)
def sleep(self, level: int = 1):
"""
......
......@@ -85,7 +85,7 @@ from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
from vllm.version import __version__ as VLLM_VERSION
......@@ -677,8 +677,12 @@ if envs.VLLM_SERVER_DEV_MODE:
Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server.
"""
logger.info("Resetting prefix cache...")
await engine_client(raw_request).reset_prefix_cache()
device = None
device_str = raw_request.query_params.get("device")
if device_str is not None:
device = Device[device_str.upper()]
logger.info("Resetting prefix cache with specific %s...", str(device))
await engine_client(raw_request).reset_prefix_cache(device)
return Response(status_code=200)
@router.post("/sleep")
......@@ -1032,6 +1036,9 @@ async def run_server(args, **uvicorn_kwargs) -> None:
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
# NOTE: When the 'disable_uvicorn_access_log' value is True,
# no access log will be output.
access_log=not args.disable_uvicorn_access_log,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
......
......@@ -89,6 +89,9 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="Log level for uvicorn.")
parser.add_argument("--disable-uvicorn-access-log",
action="store_true",
help="Disable uvicorn access log.")
parser.add_argument("--allow-credentials",
action="store_true",
help="Allow credentials.")
......@@ -286,13 +289,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
raise TypeError("Error: --enable-reasoning requires "
"--reasoning-parser")
# Ref https://api-docs.deepseek.com/guides/reasoning_model
# tool call and reasoning cannot be enabled at the same time.
if args.enable_auto_tool_choice and args.enable_reasoning:
raise TypeError(
"Error: --enable-auto-tool-choice and "
"--enable-reasoning cannot be enabled at the same time")
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
......
# SPDX-License-Identifier: Apache-2.0
import os
from abc import abstractmethod
from collections.abc import Sequence
from functools import cached_property
from typing import Callable, Optional, Union
......@@ -76,6 +77,40 @@ class ReasoningParser:
"AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.is_reasoning_end has"
"not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_content_ids has"
" not been implemented!")
class ReasoningParserManager:
reasoning_parsers: dict[str, type] = {}
......
......@@ -45,6 +45,19 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
# TODO: need to rebase by PR #14428
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.think_end_token_id not in input_ids[:-1]:
return []
else:
return input_ids[input_ids.index(self.think_end_token_id) + 1:]
def extract_reasoning_content_streaming(
self,
previous_text: str,
......
......@@ -328,6 +328,9 @@ class OpenAIServingChat(OpenAIServing):
# These are only required in "auto" tool choice case
previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices
# For reasoning parser and tool call all enabled
added_content_delta_arr = [False] * num_choices
reasoning_end_arr = [False] * num_choices
else:
previous_texts, all_previous_token_ids = None, None
......@@ -477,27 +480,116 @@ class OpenAIServingChat(OpenAIServing):
delta_message: Optional[DeltaMessage]
# handle streaming deltas for tools with named tool_choice
if tool_choice_function_name:
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=tool_choice_function_name,
arguments=delta_text),
index=i)
])
# handle streaming deltas for tools with "auto" tool choice
elif tool_choice_auto:
# just update previous_texts and previous_token_ids
if tool_choice_auto or should_stream_with_reasoning_parsing:
assert previous_texts is not None
assert all_previous_token_ids is not None
assert tool_parser is not None
#TODO optimize manipulation of these lists
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)
# handle streaming deltas for tools with named tool_choice
if tool_choice_function_name:
if (self.enable_reasoning
and not reasoning_parser.is_reasoning_end(
previous_token_ids)):
assert reasoning_parser is not None
delta_message = (
reasoning_parser.
extract_reasoning_content_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output.token_ids,
))
# When encountering think end id in delta_token_ids,
# process the `content`. Only keep 'content',
# remove 'reasoning_content'
if reasoning_parser.is_reasoning_end(
list(output.token_ids)):
if delta_message and delta_message.content:
# This need to be added to next `delta_text`
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
else:
# Just to add remaining `content`
if self.enable_reasoning:
delta_text = previous_text + delta_text
current_text = ""
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=tool_choice_function_name,
arguments=delta_text),
index=i)
])
# handle streaming deltas for tools with "auto" tool choice
# and reasoning parser
elif tool_choice_auto and self.enable_reasoning:
assert tool_parser is not None
assert reasoning_parser is not None
assert added_content_delta_arr is not None
assert reasoning_end_arr is not None
if not reasoning_end_arr[i]:
delta_message = (
reasoning_parser.
extract_reasoning_content_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output.token_ids,
))
# When encountering think end id in delta_token_ids,
# set reasoning status to end.
# Remove the text and token ids related
# to 'reasoning_content'.
if reasoning_parser.is_reasoning_end(
list(output.token_ids)):
reasoning_end_arr[i] = True
current_token_ids = \
reasoning_parser.extract_content_ids(
list(output.token_ids))
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
# handle tool calls only after reasoning is done,
else:
delta_token_ids = list(output.token_ids)
# First time to tool call,
# add the remaining text and token ids
# to delta from previous
if not added_content_delta_arr[i]:
added_content_delta_arr[i] = True
previous_text = ""
previous_token_ids = []
delta_text = current_text
delta_token_ids = current_token_ids
delta_message = (
tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request))
# when only tool calls
elif tool_choice_auto:
assert tool_parser is not None
delta_message = (
tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
......@@ -507,23 +599,9 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids,
request=request))
# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# reasoning_content cannot be enabled with tool_choice.
# If it is, the tool_choice will be used instead.
# when only reasoning
elif self.enable_reasoning:
# handle reasoning_content delta
assert reasoning_parser is not None
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)
delta_message = (reasoning_parser.
extract_reasoning_content_streaming(
previous_text,
......@@ -533,15 +611,17 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids,
output.token_ids,
))
# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# handle streaming just a content delta
else:
delta_message = DeltaMessage(content=delta_text)
# update the previous values for the next iteration
if tool_choice_auto or should_stream_with_reasoning_parsing:
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# set the previous values for the next iteration
previous_num_tokens[i] += len(output.token_ids)
......@@ -739,24 +819,24 @@ class OpenAIServingChat(OpenAIServing):
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
# If the reasoning parser is enabled,
# tool calls are extracted exclusively from the content.
reasoning_content, content = (
reasoning_parser.extract_reasoning_content(
output.text, request=request))
if reasoning_content:
message = ChatMessage(role=role,
content=content,
reasoning_content=reasoning_content)
else:
message = ChatMessage(role=role, content=output.text)
else:
reasoning_content = None
content = output.text
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
elif (not self.enable_auto_tools
or not self.tool_parser) and not isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam):
message = ChatMessage(role=role, content=output.text)
if (not self.enable_auto_tools
or not self.tool_parser) and not isinstance(
request.tool_choice,
ChatCompletionNamedToolChoiceParam):
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
# if the request uses tools and specified a tool choice
elif request.tool_choice and type(
......@@ -766,18 +846,21 @@ class OpenAIServingChat(OpenAIServing):
tokenizer, MistralTokenizer) else ToolCall
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
content="",
tool_calls=[
tool_call_class(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=output.text))
arguments=content))
])
# if the request doesn't use tool choice
# OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none":
message = ChatMessage(role=role, content=output.text)
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
# handle when there are tools and tool choice is auto
elif request.tools and (
......@@ -792,20 +875,23 @@ class OpenAIServingChat(OpenAIServing):
return self.create_error_response(str(e))
tool_call_info = tool_parser.extract_tool_calls(
output.text, request=request)
content if content is not None else "", request=request)
# In the OpenAI API the finish_reason is "tools_called"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=tool_call_info.content,
tool_calls=tool_call_info.tool_calls)
else:
# FOR NOW make it a chat message; we will have to detect
# the type to make it later.
message = ChatMessage(role=role, content=output.text)
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
# undetermined case that is still important to handle
else:
......@@ -813,7 +899,9 @@ class OpenAIServingChat(OpenAIServing):
"Error in chat_completion_full_generator - cannot determine"
" if tools should be extracted. Returning a standard chat "
"completion.")
message = ChatMessage(role=role, content=output.text)
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
choice_data = ChatCompletionResponseChoice(
index=output.index,
......
# SPDX-License-Identifier: Apache-2.0
import hashlib
import os
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Optional
......@@ -40,11 +41,8 @@ if TYPE_CHECKING:
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_MOE_PREPACK: bool = True
VLLM_OPENVINO_DEVICE: str = "CPU"
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False
......@@ -74,10 +72,13 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
Q_SCALE_CONSTANT: int = 200
K_SCALE_CONSTANT: int = 200
V_SCALE_CONSTANT: int = 100
VLLM_SERVER_DEV_MODE: bool = False
......@@ -94,6 +95,7 @@ if TYPE_CHECKING:
VLLM_DP_MASTER_PORT: int = 0
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
def get_default_cache_root():
......@@ -126,7 +128,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default),
# rocm, neuron, cpu, openvino]
# rocm, neuron, cpu]
"VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"),
......@@ -353,28 +355,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CPU_MOE_PREPACK":
lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))),
# OpenVINO device selection
# default is CPU
"VLLM_OPENVINO_DEVICE":
lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(),
# OpenVINO key-value cache space
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0")),
# OpenVINO KV cache precision
# default is bf16 if natively supported by platform, otherwise f16
# To enable KV cache compression, please, explicitly specify u8
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION":
lambda: os.getenv("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", None),
# Enables weights compression during model export via HF Optimum
# default is False
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS":
lambda:
(os.environ.get("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", "0").lower() in
("on", "true", "1")),
# If the env var is set, then all workers will execute as separate
# processes from the engine, and we use the same mechanism to trigger
# execution on all workers.
......@@ -444,6 +424,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_XLA_CACHE_PATH",
os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
)),
# If set, assert on XLA recompilation after each execution step.
"VLLM_XLA_CHECK_RECOMPILATION":
lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
"VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),
......@@ -521,16 +505,31 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),
# Disable aiter ops unless specifically enabled.
# Acts as a parent switch to enable the rest of the other operations.
"VLLM_ROCM_USE_AITER":
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),
# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
("true", "1")),
# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
# Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT":
lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
# Divisor for dynamic key scale factor calculation for FP8 KV Cache
"K_SCALE_CONSTANT":
lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
# Divisor for dynamic value scale factor calculation for FP8 KV Cache
"V_SCALE_CONSTANT":
lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
......@@ -618,6 +617,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# an environment with potentially malicious users.
"VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
# If set, disables TPU-specific optimization for top-k & top-p sampling
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
}
# end-env-vars-definition
......@@ -648,3 +652,43 @@ def set_vllm_use_v1(use_v1: bool):
"explicitly by the user. Please raise this as a Github "
"Issue and explicitly set VLLM_USE_V1=0 or 1.")
os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0"
def compute_hash() -> str:
"""
WARNING: Whenever a new key is added to this environment
variables, ensure that it is included in the factors list if
it affects the computation graph. For example, different values
of VLLM_PP_LAYER_PARTITION will generate different computation
graphs, so it is included in the factors list. The env vars that
affect the choice of different kernels or attention backends should
also be included in the factors list.
"""
factors: list[Any] = []
# summarize environment variables
def factorize(name: str):
if __getattr__(name):
factors.append(__getattr__(name))
else:
factors.append("None")
# The values of envs may affects the computation graph.
# TODO(DefTruth): hash all environment variables?
# for key in environment_variables:
# factorize(key)
environment_variables_to_hash = [
"VLLM_PP_LAYER_PARTITION",
"VLLM_MLA_DISABLE",
"VLLM_USE_TRITON_FLASH_ATTN",
"VLLM_USE_TRITON_AWQ",
"VLLM_DP_RANK",
"VLLM_DP_SIZE",
]
for key in environment_variables_to_hash:
if key in environment_variables:
factorize(key)
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -16,7 +16,7 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import _check_multiproc_method, get_mp_context, run_method
from vllm.utils import _maybe_force_spawn, get_mp_context, run_method
logger = init_logger(__name__)
......@@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config):
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""
_check_multiproc_method()
_maybe_force_spawn()
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
......
......@@ -340,6 +340,8 @@ class RayDistributedExecutor(DistributedExecutorBase):
and v not in self.non_carry_over_env_vars
]
env_vars_to_copy.extend(current_platform.additional_env_vars)
# Copy existing env vars to each worker's args
for args in all_args_to_update_environment_variables:
# TODO: refactor platform-specific env vars
......@@ -559,6 +561,15 @@ class RayDistributedExecutor(DistributedExecutorBase):
envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution,
# i.e., the distributed execution that includes model forward runs and
# intermediate tensor communications, in the case of vllm.
os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112
logger.info("RAY_CGRAPH_get_timeout is set to %s",
os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112
with InputNode() as input_data:
# Example DAG: PP=2, TP=4
#
......
......@@ -17,7 +17,7 @@ from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
......@@ -284,8 +284,9 @@ def initialize_ray_cluster(
assert_ray_available()
from vllm.platforms import current_platform
# Connect to a ray cluster.
if current_platform.is_rocm() or current_platform.is_xpu():
if ray.is_initialized():
logger.info("Ray is already initialized. Skipping Ray initialization.")
elif current_platform.is_rocm() or current_platform.is_xpu():
# Try to connect existing ray instance and create a new one if not found
try:
ray.init("auto", ignore_reinit_error=True)
......@@ -299,19 +300,21 @@ def initialize_ray_cluster(
else:
ray.init(address=ray_address, ignore_reinit_error=True)
if parallel_config.placement_group:
# Placement group is already set.
return
device_str = current_platform.ray_device_key
if not device_str:
raise ValueError(
f"current platform {current_platform.device_name} does not "
"support ray.")
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
# Create or get the placement group for worker processes
if parallel_config.placement_group:
current_placement_group = parallel_config.placement_group
else:
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
logger.info("Using the existing placement group")
# We are in a placement group
bundles = current_placement_group.bundle_specs
# Verify that we can use the placement group.
......@@ -331,6 +334,8 @@ def initialize_ray_cluster(
f"Required number of devices: {parallel_config.world_size}. "
f"Total number of devices: {device_bundles}.")
else:
logger.info("No current placement group found. "
"Creating a new placement group.")
num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
# Log a warning message and delay resource allocation failure response.
# Avoid immediate rejection to allow user-initiated placement group
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from vllm import envs
from vllm.logger import init_logger
logger = init_logger(__name__)
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason, is_fa_version_supported)
device_capability = current_platform.get_device_capability()
assert device_capability is not None
# 1. default version depending on platform
fa_version = 3 if (device_capability.major == 9
and is_fa_version_supported(3)) else 2
# 2. override if passed by environment
if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
fa_version = envs.VLLM_FLASH_ATTN_VERSION
# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform "
"defaulting to FA version 2.")
fa_version = 2
if requires_alibi and fa_version == 3:
logger.warning_once("Cannot use FA version 3 with ALiBi, "
"defaulting to FA version 2.")
fa_version = 2
if not is_fa_version_supported(fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
fa_version, fa_version_unsupported_reason(fa_version))
assert is_fa_version_supported(fa_version)
return fa_version
except (ImportError, AssertionError):
return None
......@@ -182,7 +182,6 @@ class InputPreprocessor:
def _tokenize_prompt(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> list[int]:
"""
......@@ -202,15 +201,13 @@ class InputPreprocessor:
"do_lower_case", False)):
prompt = prompt.lower()
return tokenizer.encode(request_id=request_id,
prompt=prompt,
return tokenizer.encode(prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
async def _tokenize_prompt_async(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> list[int]:
"""Async version of :meth:`_tokenize_prompt`."""
......@@ -222,7 +219,6 @@ class InputPreprocessor:
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False
return await tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
......@@ -309,7 +305,6 @@ class InputPreprocessor:
def _prompt_to_llm_inputs(
self,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> SingletonInputs:
......@@ -318,7 +313,6 @@ class InputPreprocessor:
Arguments:
* request_id
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
* return_mm_hashes: whether to return multimodal hashes
......@@ -333,7 +327,6 @@ class InputPreprocessor:
prompt_text = parsed["content"]
prompt_token_ids = self._tokenize_prompt(
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
......@@ -384,7 +377,6 @@ class InputPreprocessor:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
......@@ -400,7 +392,6 @@ class InputPreprocessor:
async def _prompt_to_llm_inputs_async(
self,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> SingletonInputs:
......@@ -411,7 +402,6 @@ class InputPreprocessor:
prompt_text = parsed["content"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
......@@ -460,7 +450,6 @@ class InputPreprocessor:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
......@@ -560,7 +549,6 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt(
self,
prompt: PromptType,
request_id: str,
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
......@@ -587,7 +575,6 @@ class InputPreprocessor:
Arguments:
* prompt: an input prompt
* request_id
Returns:
......@@ -598,16 +585,11 @@ class InputPreprocessor:
if is_explicit_encoder_decoder_prompt(prompt):
encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"],
request_id=request_id,
)
prompt["encoder_prompt"])
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None
else:
decoder_inputs = self._prompt_to_llm_inputs(
decoder_input,
request_id=request_id,
)
decoder_inputs = self._prompt_to_llm_inputs(decoder_input)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model and (
......@@ -616,10 +598,7 @@ class InputPreprocessor:
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
inputs = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
inputs = self._prompt_to_llm_inputs(prompt)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
# Encoder-Decoder Multimodal model
......@@ -636,7 +615,6 @@ class InputPreprocessor:
async def _process_encoder_decoder_prompt_async(
self,
prompt: PromptType,
request_id: str,
) -> EncoderDecoderInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_inputs: SingletonInputs
......@@ -644,18 +622,13 @@ class InputPreprocessor:
if is_explicit_encoder_decoder_prompt(prompt):
encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"],
request_id=request_id,
)
prompt["encoder_prompt"])
if (decoder_input := prompt["decoder_prompt"]) is None:
encoder_inputs = await encoder_task
decoder_inputs = None
else:
decoder_task = self._prompt_to_llm_inputs_async(
decoder_input,
request_id=request_id,
)
decoder_task = self._prompt_to_llm_inputs_async(decoder_input)
encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task)
......@@ -668,10 +641,7 @@ class InputPreprocessor:
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
inputs = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
)
inputs = await self._prompt_to_llm_inputs_async(prompt)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
# Encoder-Decoder Multimodal model
......@@ -704,7 +674,6 @@ class InputPreprocessor:
def _process_decoder_only_prompt(
self,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
......@@ -716,7 +685,6 @@ class InputPreprocessor:
Arguments:
* prompt: input prompt
* request_id
* lora_request
* prompt_adapter_request
* return_mm_hashes
......@@ -728,7 +696,6 @@ class InputPreprocessor:
prompt_comps = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
......@@ -741,7 +708,6 @@ class InputPreprocessor:
async def _process_decoder_only_prompt_async(
self,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
......@@ -749,7 +715,6 @@ class InputPreprocessor:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
......@@ -762,7 +727,6 @@ class InputPreprocessor:
def preprocess(
self,
prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
......@@ -774,10 +738,7 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(
prompt,
request_id=request_id,
)
return self._process_encoder_decoder_prompt(prompt)
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
......@@ -786,7 +747,6 @@ class InputPreprocessor:
# Decoder-only operation
return self._process_decoder_only_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes,
......@@ -795,7 +755,6 @@ class InputPreprocessor:
async def preprocess_async(
self,
prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
......@@ -807,10 +766,7 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(
prompt,
request_id=request_id,
)
return await self._process_encoder_decoder_prompt_async(prompt)
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
......@@ -819,7 +775,6 @@ class InputPreprocessor:
# Decoder-only operation
return await self._process_decoder_only_prompt_async(
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes,
......
......@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
import torch
import vllm.envs as envs
from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON
......@@ -42,8 +43,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras,
max_num_batched_tokens,
device=device)
# When cudagraph capture size is greater than max_num_seqs (max_batches,
# here), V0 captures the graph as if max_num_seqs is set to
# the capture size.
# V1 doesn't have this problem and always respects max_num_seqs.
max_num_prompts = (max_batches
if envs.VLLM_USE_V1 else max_num_batched_tokens)
self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras,
max_batches,
max_num_prompts,
device=device)
def update_metadata(
......
......@@ -79,6 +79,12 @@ def maybe_backend_fallback(
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines")
elif guided_params.json_object:
# https://github.com/mlc-ai/xgrammar/issues/256
fallback_or_error(guided_params,
"xgrammar does not support json_object.",
"guidance")
# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
......@@ -88,9 +94,9 @@ def maybe_backend_fallback(
if (guided_params.backend_name == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to xgrammar
# outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params,
"outlines does not support json_object.", "xgrammar")
"outlines does not support json_object.", "guidance")
return guided_params
......@@ -122,10 +128,15 @@ async def get_guided_decoding_logits_processor(
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner)
if guided_params.backend_name == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
guided_params, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
)
def get_local_guided_decoding_logits_processor(
......@@ -155,7 +166,13 @@ def get_local_guided_decoding_logits_processor(
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner)
if guided_params.backend_name == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
guided_params, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
)
# SPDX-License-Identifier: Apache-2.0
from re import escape as regex_escape
import llguidance
from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.guidance_logits_processors import (
GuidanceLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
def get_local_guidance_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
"""
grm = ""
if guided_params.json:
grm = llguidance.LLMatcher.grammar_from_json_schema(
guided_params.json,
overrides={"whitespace_pattern": guided_params.whitespace_pattern})
elif guided_params.json_object:
grm = llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}',
overrides={"whitespace_pattern": guided_params.whitespace_pattern})
elif guided_params.regex:
grm = llguidance.grammar_from("regex", guided_params.regex)
elif guided_params.choice:
# choice just uses regex
choices = (regex_escape(str(choice))
for choice in guided_params.choice)
choices_regex = "(" + "|".join(choices) + ")"
grm = llguidance.grammar_from("regex", choices_regex)
elif guided_params.grammar:
# this supports Lark and GBNF
grm = llguidance.grammar_from("grammar", guided_params.grammar)
if grm:
return GuidanceLogitsProcessor(grm, tokenizer)
raise ValueError("Unknown guided decoding mode")
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any, List
import llguidance
import llguidance.hf
import llguidance.torch
import torch
from transformers import PreTrainedTokenizerBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class GuidanceLogitsProcessor:
"""Base Guidance Logits Processor"""
cached_tokenizers: dict[str, Any] = {}
def __init__(
self,
grammar: str,
tokenizer: PreTrainedTokenizerBase,
) -> None:
"""Base Guidance Logits Processor
Args:
grammar (str)
grammar to guide the generation
tokenizer (PreTrainedTokenizerBase)
model's tokenizer
"""
self.grammar = grammar
self.tokenizer = tokenizer
self.tokenizer_name = tokenizer.name_or_path
self.new_sampling = False
self.initialized = False
def _initialize(self):
if self.initialized:
return
ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path,
None)
if ll_tokenizer is None:
ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer
self.ll_tokenizer = ll_tokenizer
self.ll_matcher = llguidance.LLMatcher(
self.ll_tokenizer,
self.grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
# create reusable bitmask
self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size)
self.initialized = True
def __call__(
self,
input_ids: List[int],
scores: torch.Tensor,
) -> torch.Tensor:
# we initialize the guidance model here
# to avoid pickling ll_tokenizer and ll_interpreter
self._initialize()
if self.new_sampling and len(input_ids) > 0:
self.ll_matcher.consume_token(input_ids[-1])
err = self.ll_matcher.get_error()
if err:
logger.warning("Error in LLMatcher: %s", err)
llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
0)
llguidance.torch.apply_token_bitmask_inplace(
scores, self.bitmask.to(scores.device))
self.new_sampling = True
return scores
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment