Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53076d70
Commit
53076d70
authored
Mar 24, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.2' into v0.8.2-ori
parents
322a0be6
9c5c81b0
Changes
219
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
533 additions
and
177 deletions
+533
-177
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-3
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+4
-3
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+4
-3
vllm/engine/protocol.py
vllm/engine/protocol.py
+4
-6
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+4
-3
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+10
-3
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+3
-7
vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
...ypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
+35
-0
vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
.../openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
+13
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+138
-50
vllm/envs.py
vllm/envs.py
+72
-28
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+2
-2
vllm/executor/ray_distributed_executor.py
vllm/executor/ray_distributed_executor.py
+11
-0
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+14
-9
vllm/fa_utils.py
vllm/fa_utils.py
+48
-0
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+9
-54
vllm/lora/punica_wrapper/punica_gpu.py
vllm/lora/punica_wrapper/punica_gpu.py
+9
-1
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+22
-5
vllm/model_executor/guided_decoding/guidance_decoding.py
vllm/model_executor/guided_decoding/guidance_decoding.py
+44
-0
vllm/model_executor/guided_decoding/guidance_logits_processors.py
...el_executor/guided_decoding/guidance_logits_processors.py
+85
-0
No files found.
vllm/engine/llm_engine.py
View file @
53076d70
...
...
@@ -783,7 +783,6 @@ class LLMEngine:
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
...
...
@@ -955,12 +954,12 @@ class LLMEngine:
"""
return
self
.
scheduler
[
virtual_engine
].
has_unfinished_seqs
()
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
bool
:
"""Reset prefix cache for all devices."""
success
=
True
for
scheduler
in
self
.
scheduler
:
success
=
success
and
scheduler
.
reset_prefix_cache
()
success
=
success
and
scheduler
.
reset_prefix_cache
(
device
)
return
success
@
staticmethod
...
...
vllm/engine/multiprocessing/__init__.py
View file @
53076d70
...
...
@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
deprecate_kwargs
from
vllm.utils
import
Device
,
deprecate_kwargs
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
...
...
@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE
=
2
class
RPCResetPrefixCacheRequest
(
Enum
):
RESET_PREFIX_CACHE
=
1
@
dataclass
class
RPCResetPrefixCacheRequest
:
device
:
Device
class
RPCSleepRequest
(
Enum
):
...
...
vllm/engine/multiprocessing/client.py
View file @
53076d70
...
...
@@ -47,7 +47,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
deprecate_kwargs
from
vllm.utils
import
Device
,
deprecate_kwargs
logger
=
init_logger
(
__name__
)
...
...
@@ -684,11 +684,12 @@ class MQLLMEngineClient(EngineClient):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUProfileRequest
.
STOP_PROFILE
,
socket
=
self
.
input_socket
)
async
def
reset_prefix_cache
(
self
)
->
None
:
async
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
None
:
"""Reset the prefix cache"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCResetPrefixCacheRequest
.
RESET_PREFIX_CACHE
,
request
=
RPCResetPrefixCacheRequest
(
device
)
,
socket
=
self
.
input_socket
)
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
...
...
vllm/engine/protocol.py
View file @
53076d70
...
...
@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
collect_from_async_generator
,
random_uuid
from
vllm.utils
import
Device
,
collect_from_async_generator
,
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -81,10 +81,7 @@ class EngineClient(ABC):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
NotImplementedError
else
:
processed_inputs
=
preprocessor
.
_prompt_to_llm_inputs
(
prompt
,
request_id
=
request_id
,
)
processed_inputs
=
preprocessor
.
_prompt_to_llm_inputs
(
prompt
)
prompt_token_ids
=
processed_inputs
[
"prompt_token_ids"
]
prompt_text
=
processed_inputs
.
get
(
"prompt"
)
...
...
@@ -274,7 +271,8 @@ class EngineClient(ABC):
...
@
abstractmethod
async
def
reset_prefix_cache
(
self
)
->
None
:
async
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
None
:
"""Reset the prefix cache"""
...
...
...
vllm/entrypoints/llm.py
View file @
53076d70
...
...
@@ -42,7 +42,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
)
logger
=
init_logger
(
__name__
)
...
...
@@ -1187,8 +1188,8 @@ class LLM:
def
stop_profile
(
self
)
->
None
:
self
.
llm_engine
.
stop_profile
()
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
llm_engine
.
reset_prefix_cache
()
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
bool
:
return
self
.
llm_engine
.
reset_prefix_cache
(
device
)
def
sleep
(
self
,
level
:
int
=
1
):
"""
...
...
vllm/entrypoints/openai/api_server.py
View file @
53076d70
...
...
@@ -85,7 +85,7 @@ from vllm.logger import init_logger
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
from
vllm.utils
import
(
Device
,
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
is_valid_ipv6_address
,
set_ulimit
)
from
vllm.version
import
__version__
as
VLLM_VERSION
...
...
@@ -677,8 +677,12 @@ if envs.VLLM_SERVER_DEV_MODE:
Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server.
"""
logger
.
info
(
"Resetting prefix cache..."
)
await
engine_client
(
raw_request
).
reset_prefix_cache
()
device
=
None
device_str
=
raw_request
.
query_params
.
get
(
"device"
)
if
device_str
is
not
None
:
device
=
Device
[
device_str
.
upper
()]
logger
.
info
(
"Resetting prefix cache with specific %s..."
,
str
(
device
))
await
engine_client
(
raw_request
).
reset_prefix_cache
(
device
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/sleep"
)
...
...
@@ -1032,6 +1036,9 @@ async def run_server(args, **uvicorn_kwargs) -> None:
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
uvicorn_log_level
,
# NOTE: When the 'disable_uvicorn_access_log' value is True,
# no access log will be output.
access_log
=
not
args
.
disable_uvicorn_access_log
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
,
...
...
vllm/entrypoints/openai/cli_args.py
View file @
53076d70
...
...
@@ -89,6 +89,9 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default
=
"info"
,
choices
=
[
'debug'
,
'info'
,
'warning'
,
'error'
,
'critical'
,
'trace'
],
help
=
"Log level for uvicorn."
)
parser
.
add_argument
(
"--disable-uvicorn-access-log"
,
action
=
"store_true"
,
help
=
"Disable uvicorn access log."
)
parser
.
add_argument
(
"--allow-credentials"
,
action
=
"store_true"
,
help
=
"Allow credentials."
)
...
...
@@ -286,13 +289,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
raise
TypeError
(
"Error: --enable-reasoning requires "
"--reasoning-parser"
)
# Ref https://api-docs.deepseek.com/guides/reasoning_model
# tool call and reasoning cannot be enabled at the same time.
if
args
.
enable_auto_tool_choice
and
args
.
enable_reasoning
:
raise
TypeError
(
"Error: --enable-auto-tool-choice and "
"--enable-reasoning cannot be enabled at the same time"
)
def
create_parser_for_docs
()
->
FlexibleArgumentParser
:
parser_for_docs
=
FlexibleArgumentParser
(
...
...
vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
import
os
from
abc
import
abstractmethod
from
collections.abc
import
Sequence
from
functools
import
cached_property
from
typing
import
Callable
,
Optional
,
Union
...
...
@@ -76,6 +77,40 @@ class ReasoningParser:
"AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!"
)
# TODO: need to rebase by PR #14428
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """
    Check if the reasoning content ends in the input_ids.

    Parameters:
    input_ids: list[int]
        The input_ids of the model output.

    Returns:
    bool
        True if the reasoning content ends in the input_ids.

    Raises:
    NotImplementedError
        Always, in this base implementation; concrete parsers override it.
    """
    # The trailing space after "has" matters: adjacent string literals are
    # concatenated verbatim, and without it the message read "hasnot".
    raise NotImplementedError(
        "AbstractReasoningParser.is_reasoning_end has "
        "not been implemented!")
# TODO: need to rebase by PR #14428
@
abstractmethod
def
extract_content_ids
(
self
,
input_ids
:
list
[
int
])
->
list
[
int
]:
"""
Extract content token ids from the input_ids.
This base implementation only raises; concrete parsers override it.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The content token ids extracted from the input_ids.
Raises:
NotImplementedError
Always, in this base implementation.
"""
raise
NotImplementedError
(
"AbstractReasoningParser.extract_content_ids has"
" not been implemented!"
)
class
ReasoningParserManager
:
reasoning_parsers
:
dict
[
str
,
type
]
=
{}
...
...
vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
View file @
53076d70
...
...
@@ -45,6 +45,19 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!"
)
# TODO: need to rebase by PR #14428
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
"""Return True if self.think_end_token_id occurs anywhere in input_ids."""
return
self
.
think_end_token_id
in
input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """
    Extract the content token ids that follow the think-end token.

    Parameters:
    input_ids: list[int]
        The input_ids of the model output.

    Returns:
    list[int]
        The ids after the first occurrence of ``self.think_end_token_id``.
        Empty when that token is absent or is the last id in input_ids.
    """
    try:
        # A single scan replaces the previous "copy the list, test
        # membership, then search again" sequence.
        end_index = input_ids.index(self.think_end_token_id)
    except ValueError:
        # The end token has not been produced, so there is no content.
        return []
    # If the end token is the final id, this slice is naturally empty,
    # matching the old explicit check against input_ids[:-1].
    return input_ids[end_index + 1:]
def
extract_reasoning_content_streaming
(
self
,
previous_text
:
str
,
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
53076d70
...
...
@@ -328,6 +328,9 @@ class OpenAIServingChat(OpenAIServing):
# These are only required in "auto" tool choice case
previous_texts
=
[
""
]
*
num_choices
all_previous_token_ids
=
[[]]
*
num_choices
# Used when both the reasoning parser and tool calling are enabled
added_content_delta_arr
=
[
False
]
*
num_choices
reasoning_end_arr
=
[
False
]
*
num_choices
else
:
previous_texts
,
all_previous_token_ids
=
None
,
None
...
...
@@ -477,27 +480,116 @@ class OpenAIServingChat(OpenAIServing):
delta_message
:
Optional
[
DeltaMessage
]
# handle streaming deltas for tools with named tool_choice
if
tool_choice_function_name
:
delta_message
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
function
=
DeltaFunctionCall
(
name
=
tool_choice_function_name
,
arguments
=
delta_text
),
index
=
i
)
])
# handle streaming deltas for tools with "auto" tool choice
elif
tool_choice_auto
:
# just update previous_texts and previous_token_ids
if
tool_choice_auto
or
should_stream_with_reasoning_parsing
:
assert
previous_texts
is
not
None
assert
all_previous_token_ids
is
not
None
assert
tool_parser
is
not
None
#TODO optimize manipulation of these lists
previous_text
=
previous_texts
[
i
]
previous_token_ids
=
all_previous_token_ids
[
i
]
current_text
=
previous_text
+
delta_text
current_token_ids
=
previous_token_ids
+
list
(
output
.
token_ids
)
# handle streaming deltas for tools with named tool_choice
if
tool_choice_function_name
:
if
(
self
.
enable_reasoning
and
not
reasoning_parser
.
is_reasoning_end
(
previous_token_ids
)):
assert
reasoning_parser
is
not
None
delta_message
=
(
reasoning_parser
.
extract_reasoning_content_streaming
(
previous_text
,
current_text
,
delta_text
,
previous_token_ids
,
current_token_ids
,
output
.
token_ids
,
))
# When encountering think end id in delta_token_ids,
# process the `content`. Only keep 'content',
# remove 'reasoning_content'
if
reasoning_parser
.
is_reasoning_end
(
list
(
output
.
token_ids
)):
if
delta_message
and
delta_message
.
content
:
# This needs to be added to the next `delta_text`
current_text
=
delta_message
.
content
delta_message
.
content
=
None
else
:
current_text
=
""
else
:
# Just to add remaining `content`
if
self
.
enable_reasoning
:
delta_text
=
previous_text
+
delta_text
current_text
=
""
delta_message
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
function
=
DeltaFunctionCall
(
name
=
tool_choice_function_name
,
arguments
=
delta_text
),
index
=
i
)
])
# handle streaming deltas for tools with "auto" tool choice
# and reasoning parser
elif
tool_choice_auto
and
self
.
enable_reasoning
:
assert
tool_parser
is
not
None
assert
reasoning_parser
is
not
None
assert
added_content_delta_arr
is
not
None
assert
reasoning_end_arr
is
not
None
if
not
reasoning_end_arr
[
i
]:
delta_message
=
(
reasoning_parser
.
extract_reasoning_content_streaming
(
previous_text
,
current_text
,
delta_text
,
previous_token_ids
,
current_token_ids
,
output
.
token_ids
,
))
# When encountering think end id in delta_token_ids,
# set reasoning status to end.
# Remove the text and token ids related
# to 'reasoning_content'.
if
reasoning_parser
.
is_reasoning_end
(
list
(
output
.
token_ids
)):
reasoning_end_arr
[
i
]
=
True
current_token_ids
=
\
reasoning_parser
.
extract_content_ids
(
list
(
output
.
token_ids
))
if
delta_message
and
delta_message
.
content
:
current_text
=
delta_message
.
content
delta_message
.
content
=
None
else
:
current_text
=
""
# handle tool calls only after reasoning is done,
else
:
delta_token_ids
=
list
(
output
.
token_ids
)
# First time to tool call,
# add the remaining text and token ids
# to delta from previous
if
not
added_content_delta_arr
[
i
]:
added_content_delta_arr
[
i
]
=
True
previous_text
=
""
previous_token_ids
=
[]
delta_text
=
current_text
delta_token_ids
=
current_token_ids
delta_message
=
(
tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
previous_text
,
current_text
=
current_text
,
delta_text
=
delta_text
,
previous_token_ids
=
previous_token_ids
,
current_token_ids
=
current_token_ids
,
delta_token_ids
=
delta_token_ids
,
request
=
request
))
# when only tool calls
elif
tool_choice_auto
:
assert
tool_parser
is
not
None
delta_message
=
(
tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
previous_text
,
...
...
@@ -507,23 +599,9 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids
=
current_token_ids
,
delta_token_ids
=
output
.
token_ids
,
request
=
request
))
# update the previous values for the next iteration
previous_texts
[
i
]
=
current_text
all_previous_token_ids
[
i
]
=
current_token_ids
# reasoning_content cannot be enabled with tool_choice.
# If it is, the tool_choice will be used instead.
# when only reasoning
elif
self
.
enable_reasoning
:
# handle reasoning_content delta
assert
reasoning_parser
is
not
None
assert
previous_texts
is
not
None
assert
all_previous_token_ids
is
not
None
previous_text
=
previous_texts
[
i
]
previous_token_ids
=
all_previous_token_ids
[
i
]
current_text
=
previous_text
+
delta_text
current_token_ids
=
previous_token_ids
+
list
(
output
.
token_ids
)
delta_message
=
(
reasoning_parser
.
extract_reasoning_content_streaming
(
previous_text
,
...
...
@@ -533,15 +611,17 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids
,
output
.
token_ids
,
))
# update the previous values for the next iteration
previous_texts
[
i
]
=
current_text
all_previous_token_ids
[
i
]
=
current_token_ids
# handle streaming just a content delta
else
:
delta_message
=
DeltaMessage
(
content
=
delta_text
)
# update the previous values for the next iteration
if
tool_choice_auto
or
should_stream_with_reasoning_parsing
:
assert
previous_texts
is
not
None
assert
all_previous_token_ids
is
not
None
previous_texts
[
i
]
=
current_text
all_previous_token_ids
[
i
]
=
current_token_ids
# set the previous values for the next iteration
previous_num_tokens
[
i
]
+=
len
(
output
.
token_ids
)
...
...
@@ -739,24 +819,24 @@ class OpenAIServingChat(OpenAIServing):
except
RuntimeError
as
e
:
logger
.
exception
(
"Error in reasoning parser creation."
)
return
self
.
create_error_response
(
str
(
e
))
# If the reasoning parser is enabled,
# tool calls are extracted exclusively from the content.
reasoning_content
,
content
=
(
reasoning_parser
.
extract_reasoning_content
(
output
.
text
,
request
=
request
))
if
reasoning_content
:
message
=
ChatMessage
(
role
=
role
,
content
=
content
,
reasoning_content
=
reasoning_content
)
else
:
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
)
else
:
reasoning_content
=
None
content
=
output
.
text
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
elif
(
not
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
not
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
):
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
)
if
(
not
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
not
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
):
message
=
ChatMessage
(
role
=
role
,
reasoning_content
=
reasoning_content
,
content
=
content
)
# if the request uses tools and specified a tool choice
elif
request
.
tool_choice
and
type
(
...
...
@@ -766,18 +846,21 @@ class OpenAIServingChat(OpenAIServing):
tokenizer
,
MistralTokenizer
)
else
ToolCall
message
=
ChatMessage
(
role
=
role
,
reasoning_content
=
reasoning_content
,
content
=
""
,
tool_calls
=
[
tool_call_class
(
function
=
FunctionCall
(
name
=
request
.
tool_choice
.
function
.
name
,
arguments
=
output
.
te
x
t
))
arguments
=
con
te
n
t
))
])
# if the request doesn't use tool choice
# OR specifies to not use a tool
elif
not
request
.
tool_choice
or
request
.
tool_choice
==
"none"
:
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
)
message
=
ChatMessage
(
role
=
role
,
reasoning_content
=
reasoning_content
,
content
=
content
)
# handle when there are tools and tool choice is auto
elif
request
.
tools
and
(
...
...
@@ -792,20 +875,23 @@ class OpenAIServingChat(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
tool_call_info
=
tool_parser
.
extract_tool_calls
(
output
.
text
,
request
=
request
)
content
if
content
is
not
None
else
""
,
request
=
request
)
# In the OpenAI API the finish_reason is "tools_called"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called
=
tool_call_info
.
tools_called
if
tool_call_info
.
tools_called
:
message
=
ChatMessage
(
role
=
role
,
reasoning_content
=
reasoning_content
,
content
=
tool_call_info
.
content
,
tool_calls
=
tool_call_info
.
tool_calls
)
else
:
# FOR NOW make it a chat message; we will have to detect
# the type to make it later.
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
)
message
=
ChatMessage
(
role
=
role
,
reasoning_content
=
reasoning_content
,
content
=
content
)
# undetermined case that is still important to handle
else
:
...
...
@@ -813,7 +899,9 @@ class OpenAIServingChat(OpenAIServing):
"Error in chat_completion_full_generator - cannot determine"
" if tools should be extracted. Returning a standard chat "
"completion."
)
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
)
message
=
ChatMessage
(
role
=
role
,
reasoning_content
=
reasoning_content
,
content
=
content
)
choice_data
=
ChatCompletionResponseChoice
(
index
=
output
.
index
,
...
...
vllm/envs.py
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
import
hashlib
import
os
import
tempfile
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
...
...
@@ -40,11 +41,8 @@ if TYPE_CHECKING:
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
VLLM_CPU_MOE_PREPACK
:
bool
=
True
VLLM_OPENVINO_DEVICE
:
str
=
"CPU"
VLLM_OPENVINO_KVCACHE_SPACE
:
int
=
0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
:
Optional
[
str
]
=
None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
:
bool
=
False
VLLM_XLA_CACHE_PATH
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"xla_cache"
)
VLLM_XLA_CHECK_RECOMPILATION
:
bool
=
False
VLLM_FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_USE_RAY_SPMD_WORKER
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
...
...
@@ -74,10 +72,13 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_DISABLED_KERNELS
:
list
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
True
VLLM_ROCM_USE_AITER
:
bool
=
False
VLLM_ROCM_USE_AITER_RMSNORM
:
bool
=
True
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
VLLM_LOG_BATCHSIZE_INTERVAL
:
float
=
-
1
VLLM_DISABLE_COMPILE_CACHE
:
bool
=
False
Q_SCALE_CONSTANT
:
int
=
200
K_SCALE_CONSTANT
:
int
=
200
V_SCALE_CONSTANT
:
int
=
100
VLLM_SERVER_DEV_MODE
:
bool
=
False
...
...
@@ -94,6 +95,7 @@ if TYPE_CHECKING:
VLLM_DP_MASTER_PORT
:
int
=
0
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
VLLM_V0_USE_OUTLINES_CACHE
:
bool
=
False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -126,7 +128,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default),
# rocm, neuron, cpu
, openvino
]
# rocm, neuron, cpu]
"VLLM_TARGET_DEVICE"
:
lambda
:
os
.
getenv
(
"VLLM_TARGET_DEVICE"
,
"cuda"
),
...
...
@@ -353,28 +355,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CPU_MOE_PREPACK"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_CPU_MOE_PREPACK"
,
"1"
))),
# OpenVINO device selection
# default is CPU
"VLLM_OPENVINO_DEVICE"
:
lambda
:
os
.
getenv
(
"VLLM_OPENVINO_DEVICE"
,
"CPU"
).
upper
(),
# OpenVINO key-value cache space
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_OPENVINO_KVCACHE_SPACE"
,
"0"
)),
# OpenVINO KV cache precision
# default is bf16 if natively supported by platform, otherwise f16
# To enable KV cache compression, please, explicitly specify u8
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION"
:
lambda
:
os
.
getenv
(
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION"
,
None
),
# Enables weights compression during model export via HF Optimum
# default is False
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"
,
"0"
).
lower
()
in
(
"on"
,
"true"
,
"1"
)),
# If the env var is set, then all workers will execute as separate
# processes from the engine, and we use the same mechanism to trigger
# execution on all workers.
...
...
@@ -444,6 +424,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_XLA_CACHE_PATH"
,
os
.
path
.
join
(
get_default_cache_root
(),
"vllm"
,
"xla_cache"
),
)),
# If set, assert on XLA recompilation after each execution step.
"VLLM_XLA_CHECK_RECOMPILATION"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_XLA_CHECK_RECOMPILATION"
,
"0"
))),
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"32768"
)),
...
...
@@ -521,16 +505,31 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V1"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_V1"
,
"1"
))),
# Disable aiter ops unless specifically enabled.
# Acts as a parent switch to enable the rest of the other operations.
"VLLM_ROCM_USE_AITER"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_RMSNORM"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ROCM_FP8_PADDING"
,
"1"
))),
# Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT"
:
lambda
:
int
(
os
.
getenv
(
"Q_SCALE_CONSTANT"
,
"200"
)),
# Divisor for dynamic key scale factor calculation for FP8 KV Cache
"K_SCALE_CONSTANT"
:
lambda
:
int
(
os
.
getenv
(
"K_SCALE_CONSTANT"
,
"200"
)),
# Divisor for dynamic value scale factor calculation for FP8 KV Cache
"V_SCALE_CONSTANT"
:
lambda
:
int
(
os
.
getenv
(
"V_SCALE_CONSTANT"
,
"100"
)),
# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"1"
))),
...
...
@@ -618,6 +617,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# an environment with potentially malicious users.
"VLLM_V0_USE_OUTLINES_CACHE"
:
lambda
:
os
.
environ
.
get
(
"VLLM_V0_USE_OUTLINES_CACHE"
,
"0"
)
==
"1"
,
# If set, disables TPU-specific optimization for top-k & top-p sampling
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"
:
lambda
:
bool
(
int
(
os
.
environ
[
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"
]))
if
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"
in
os
.
environ
else
None
,
}
# end-env-vars-definition
...
...
@@ -648,3 +652,43 @@ def set_vllm_use_v1(use_v1: bool):
"explicitly by the user. Please raise this as a Github "
"Issue and explicitly set VLLM_USE_V1=0 or 1."
)
os
.
environ
[
"VLLM_USE_V1"
]
=
"1"
if
use_v1
else
"0"
def compute_hash() -> str:
    """
    Return an MD5 hex digest summarizing selected environment variables.

    WARNING: Whenever a new key is added to this environment
    variables, ensure that it is included in the factors list if
    it affects the computation graph. For example, different values
    of VLLM_PP_LAYER_PARTITION will generate different computation
    graphs, so it is included in the factors list. The env vars that
    affect the choice of different kernels or attention backends should
    also be included in the factors list.

    Returns:
        The hex digest of ``str(factors)``, where ``factors`` holds one
        entry per hashed variable, in the order listed below.
    """
    factors: list[Any] = []

    # summarize environment variables
    def factorize(name: str):
        # Resolve the variable once instead of evaluating it twice.
        value = __getattr__(name)
        # Every falsy value (None, 0, False, "") is recorded as the string
        # "None"; this is kept as-is so existing hash values do not change.
        factors.append(value if value else "None")

    # The values of envs may affect the computation graph.
    # TODO(DefTruth): hash all environment variables?
    # for key in environment_variables:
    #     factorize(key)
    environment_variables_to_hash = [
        "VLLM_PP_LAYER_PARTITION",
        "VLLM_MLA_DISABLE",
        "VLLM_USE_TRITON_FLASH_ATTN",
        "VLLM_USE_TRITON_AWQ",
        "VLLM_DP_RANK",
        "VLLM_DP_SIZE",
    ]
    for key in environment_variables_to_hash:
        # Names that are not registered in environment_variables are skipped.
        if key in environment_variables:
            factorize(key)

    # MD5 is used purely as a fingerprint here. Marking it as not used for
    # security keeps the call working on FIPS-restricted Python builds and
    # does not change the resulting digest.
    hash_str = hashlib.md5(str(factors).encode(),
                           usedforsecurity=False).hexdigest()

    return hash_str
vllm/executor/multiproc_worker_utils.py
View file @
53076d70
...
...
@@ -16,7 +16,7 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
_
check_multiproc_method
,
get_mp_context
,
run_method
from
vllm.utils
import
_
maybe_force_spawn
,
get_mp_context
,
run_method
logger
=
init_logger
(
__name__
)
...
...
@@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config):
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""
_
check_multiproc_method
()
_
maybe_force_spawn
()
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
...
...
vllm/executor/ray_distributed_executor.py
View file @
53076d70
...
...
@@ -340,6 +340,8 @@ class RayDistributedExecutor(DistributedExecutorBase):
and
v
not
in
self
.
non_carry_over_env_vars
]
env_vars_to_copy
.
extend
(
current_platform
.
additional_env_vars
)
# Copy existing env vars to each worker's args
for
args
in
all_args_to_update_environment_variables
:
# TODO: refactor platform-specific env vars
...
...
@@ -559,6 +561,15 @@ class RayDistributedExecutor(DistributedExecutorBase):
envs
.
VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL
)
logger
.
info
(
"VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s"
,
envs
.
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM
)
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution,
# i.e., the distributed execution that includes model forward runs and
# intermediate tensor communications, in the case of vllm.
os
.
environ
.
setdefault
(
"RAY_CGRAPH_get_timeout"
,
"300"
)
# noqa: SIM112
logger
.
info
(
"RAY_CGRAPH_get_timeout is set to %s"
,
os
.
environ
[
"RAY_CGRAPH_get_timeout"
])
# noqa: SIM112
with
InputNode
()
as
input_data
:
# Example DAG: PP=2, TP=4
#
...
...
vllm/executor/ray_utils.py
View file @
53076d70
...
...
@@ -17,7 +17,7 @@ from vllm.utils import get_ip
from
vllm.worker.worker_base
import
WorkerWrapperBase
if
TYPE_CHECKING
:
from
vllm.v1.core.sched
uler
import
SchedulerOutput
from
vllm.v1.core.sched
.output
import
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
logger
=
init_logger
(
__name__
)
...
...
@@ -284,8 +284,9 @@ def initialize_ray_cluster(
assert_ray_available
()
from
vllm.platforms
import
current_platform
# Connect to a ray cluster.
if
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
if
ray
.
is_initialized
():
logger
.
info
(
"Ray is already initialized. Skipping Ray initialization."
)
elif
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
# Try to connect existing ray instance and create a new one if not found
try
:
ray
.
init
(
"auto"
,
ignore_reinit_error
=
True
)
...
...
@@ -299,19 +300,21 @@ def initialize_ray_cluster(
else
:
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
)
if
parallel_config
.
placement_group
:
# Placement group is already set.
return
device_str
=
current_platform
.
ray_device_key
if
not
device_str
:
raise
ValueError
(
f
"current platform
{
current_platform
.
device_name
}
does not "
"support ray."
)
# Create placement group for worker processes
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
# Create or get the placement group for worker processes
if
parallel_config
.
placement_group
:
current_placement_group
=
parallel_config
.
placement_group
else
:
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
if
current_placement_group
:
logger
.
info
(
"Using the existing placement group"
)
# We are in a placement group
bundles
=
current_placement_group
.
bundle_specs
# Verify that we can use the placement group.
...
...
@@ -331,6 +334,8 @@ def initialize_ray_cluster(
f
"Required number of devices:
{
parallel_config
.
world_size
}
. "
f
"Total number of devices:
{
device_bundles
}
."
)
else
:
logger
.
info
(
"No current placement group found. "
"Creating a new placement group."
)
num_devices_in_cluster
=
ray
.
cluster_resources
().
get
(
device_str
,
0
)
# Log a warning message and delay resource allocation failure response.
# Avoid immediate rejection to allow user-initiated placement group
...
...
vllm/fa_utils.py
0 → 100644
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
from
vllm
import
envs
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
    """Select the FlashAttention (FA) major version to use, if any.

    The version is chosen in three steps:

    1. Default: 3 when the device capability major is 9 and FA 3 is
       reported as supported, otherwise 2.
    2. Override: ``envs.VLLM_FLASH_ATTN_VERSION`` wins when it is set.
    3. Fallback: FA 3 is downgraded to FA 2 when the device capability
       major is 10 or when ``requires_alibi`` is true.

    Args:
        requires_alibi: Whether the caller needs ALiBi; FA 3 is never
            returned in that case.

    Returns:
        The selected FA version, or ``None`` when the flash-attn interface
        cannot be imported, the device capability is unavailable, the
        environment override is not 2 or 3, or the selected version is not
        supported.
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason, is_fa_version_supported)

        device_capability = current_platform.get_device_capability()
        # Explicit checks instead of `assert`: assertions are stripped under
        # `python -O`, which would turn these validations into no-ops.
        if device_capability is None:
            return None

        # 1. default version depending on platform
        fa_version = 3 if (device_capability.major == 9
                           and is_fa_version_supported(3)) else 2

        # 2. override if passed by environment
        if envs.VLLM_FLASH_ATTN_VERSION is not None:
            if envs.VLLM_FLASH_ATTN_VERSION not in (2, 3):
                return None
            fa_version = envs.VLLM_FLASH_ATTN_VERSION

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform "
                "defaulting to FA version 2.")
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once("Cannot use FA version 3 with ALiBi, "
                                "defaulting to FA version 2.")
            fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error("Cannot use FA version %d is not supported due to %s",
                         fa_version, fa_version_unsupported_reason(fa_version))
            return None

        return fa_version
    except (ImportError, AssertionError):
        # AssertionError is still caught so that an assertion raised inside
        # the imported helpers keeps mapping to "no usable FA version".
        return None
vllm/inputs/preprocess.py
View file @
53076d70
...
...
@@ -182,7 +182,6 @@ class InputPreprocessor:
def
_tokenize_prompt
(
self
,
prompt
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
],
)
->
list
[
int
]:
"""
...
...
@@ -202,15 +201,13 @@ class InputPreprocessor:
"do_lower_case"
,
False
)):
prompt
=
prompt
.
lower
()
return
tokenizer
.
encode
(
request_id
=
request_id
,
prompt
=
prompt
,
return
tokenizer
.
encode
(
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
async
def
_tokenize_prompt_async
(
self
,
prompt
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
],
)
->
list
[
int
]:
"""Async version of :meth:`_tokenize_prompt`."""
...
...
@@ -222,7 +219,6 @@ class InputPreprocessor:
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens
=
False
return
await
tokenizer
.
encode_async
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
...
...
@@ -309,7 +305,6 @@ class InputPreprocessor:
def
_prompt_to_llm_inputs
(
self
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
SingletonInputs
:
...
...
@@ -318,7 +313,6 @@ class InputPreprocessor:
Arguments:
* request_id
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
* return_mm_hashes: whether to return multimodal hashes
...
...
@@ -333,7 +327,6 @@ class InputPreprocessor:
prompt_text
=
parsed
[
"content"
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
...
...
@@ -384,7 +377,6 @@ class InputPreprocessor:
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
...
...
@@ -400,7 +392,6 @@ class InputPreprocessor:
async
def
_prompt_to_llm_inputs_async
(
self
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
SingletonInputs
:
...
...
@@ -411,7 +402,6 @@ class InputPreprocessor:
prompt_text
=
parsed
[
"content"
]
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
...
...
@@ -460,7 +450,6 @@ class InputPreprocessor:
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
...
...
@@ -560,7 +549,6 @@ class InputPreprocessor:
def
_process_encoder_decoder_prompt
(
self
,
prompt
:
PromptType
,
request_id
:
str
,
)
->
EncoderDecoderInputs
:
"""
For encoder/decoder models only:
...
...
@@ -587,7 +575,6 @@ class InputPreprocessor:
Arguments:
* prompt: an input prompt
* request_id
Returns:
...
...
@@ -598,16 +585,11 @@ class InputPreprocessor:
if
is_explicit_encoder_decoder_prompt
(
prompt
):
encoder_inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
)
prompt
[
"encoder_prompt"
])
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
decoder_inputs
=
None
else
:
decoder_inputs
=
self
.
_prompt_to_llm_inputs
(
decoder_input
,
request_id
=
request_id
,
)
decoder_inputs
=
self
.
_prompt_to_llm_inputs
(
decoder_input
)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
and
(
...
...
@@ -616,10 +598,7 @@ class InputPreprocessor:
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
,
request_id
=
request_id
,
)
inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
# Encoder-Decoder Multimodal model
...
...
@@ -636,7 +615,6 @@ class InputPreprocessor:
async
def
_process_encoder_decoder_prompt_async
(
self
,
prompt
:
PromptType
,
request_id
:
str
,
)
->
EncoderDecoderInputs
:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_inputs
:
SingletonInputs
...
...
@@ -644,18 +622,13 @@ class InputPreprocessor:
if
is_explicit_encoder_decoder_prompt
(
prompt
):
encoder_task
=
self
.
_prompt_to_llm_inputs_async
(
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
)
prompt
[
"encoder_prompt"
])
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
encoder_inputs
=
await
encoder_task
decoder_inputs
=
None
else
:
decoder_task
=
self
.
_prompt_to_llm_inputs_async
(
decoder_input
,
request_id
=
request_id
,
)
decoder_task
=
self
.
_prompt_to_llm_inputs_async
(
decoder_input
)
encoder_inputs
,
decoder_inputs
=
await
asyncio
.
gather
(
encoder_task
,
decoder_task
)
...
...
@@ -668,10 +641,7 @@ class InputPreprocessor:
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
request_id
=
request_id
,
)
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
# Encoder-Decoder Multimodal model
...
...
@@ -704,7 +674,6 @@ class InputPreprocessor:
def
_process_decoder_only_prompt
(
self
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
...
...
@@ -716,7 +685,6 @@ class InputPreprocessor:
Arguments:
* prompt: input prompt
* request_id
* lora_request
* prompt_adapter_request
* return_mm_hashes
...
...
@@ -728,7 +696,6 @@ class InputPreprocessor:
prompt_comps
=
self
.
_prompt_to_llm_inputs
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
...
...
@@ -741,7 +708,6 @@ class InputPreprocessor:
async
def
_process_decoder_only_prompt_async
(
self
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
...
...
@@ -749,7 +715,6 @@ class InputPreprocessor:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
...
...
@@ -762,7 +727,6 @@ class InputPreprocessor:
def
preprocess
(
self
,
prompt
:
PromptType
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
...
...
@@ -774,10 +738,7 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1."
)
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return
self
.
_process_encoder_decoder_prompt
(
prompt
,
request_id
=
request_id
,
)
return
self
.
_process_encoder_decoder_prompt
(
prompt
)
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
...
...
@@ -786,7 +747,6 @@ class InputPreprocessor:
# Decoder-only operation
return
self
.
_process_decoder_only_prompt
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
return_mm_hashes
=
return_mm_hashes
,
...
...
@@ -795,7 +755,6 @@ class InputPreprocessor:
async
def
preprocess_async
(
self
,
prompt
:
PromptType
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
...
...
@@ -807,10 +766,7 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1."
)
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return
await
self
.
_process_encoder_decoder_prompt_async
(
prompt
,
request_id
=
request_id
,
)
return
await
self
.
_process_encoder_decoder_prompt_async
(
prompt
)
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
...
...
@@ -819,7 +775,6 @@ class InputPreprocessor:
# Decoder-only operation
return
await
self
.
_process_decoder_only_prompt_async
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
return_mm_hashes
=
return_mm_hashes
,
...
...
vllm/lora/punica_wrapper/punica_gpu.py
View file @
53076d70
...
...
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
import
torch
import
vllm.envs
as
envs
from
vllm.lora.layers
import
LoRAMapping
from
vllm.triton_utils
import
HAS_TRITON
...
...
@@ -42,8 +43,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self
.
token_mapping_meta
=
LoRAKernelMeta
.
make
(
self
.
max_loras
,
max_num_batched_tokens
,
device
=
device
)
# When cudagraph capture size is greater than max_num_seqs (max_batches,
# here), V0 captures the graph as if max_num_seqs is set to
# the capture size.
# V1 doesn't have this problem and always respects max_num_seqs.
max_num_prompts
=
(
max_batches
if
envs
.
VLLM_USE_V1
else
max_num_batched_tokens
)
self
.
prompt_mapping_meta
=
LoRAKernelMeta
.
make
(
self
.
max_loras
,
max_
batche
s
,
max_
num_prompt
s
,
device
=
device
)
def
update_metadata
(
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
53076d70
...
...
@@ -79,6 +79,12 @@ def maybe_backend_fallback(
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF."
,
"outlines"
)
elif
guided_params
.
json_object
:
# https://github.com/mlc-ai/xgrammar/issues/256
fallback_or_error
(
guided_params
,
"xgrammar does not support json_object."
,
"guidance"
)
# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif
not
xgr_installed
:
...
...
@@ -88,9 +94,9 @@ def maybe_backend_fallback(
if
(
guided_params
.
backend_name
==
"outlines"
and
guided_params
.
json_object
is
not
None
):
# outlines doesn't support json_object, fallback to
xgrammar
# outlines doesn't support json_object, fallback to
guidance
fallback_or_error
(
guided_params
,
"outlines does not support json_object."
,
"
xgrammar
"
)
"outlines does not support json_object."
,
"
guidance
"
)
return
guided_params
...
...
@@ -122,10 +128,15 @@ async def get_guided_decoding_logits_processor(
get_local_xgrammar_guided_decoding_logits_processor
)
return
get_local_xgrammar_guided_decoding_logits_processor
(
guided_params
,
tokenizer
,
model_config
,
reasoner
)
if
guided_params
.
backend_name
==
'guidance'
:
from
vllm.model_executor.guided_decoding.guidance_decoding
import
(
get_local_guidance_guided_decoding_logits_processor
)
return
get_local_guidance_guided_decoding_logits_processor
(
guided_params
,
tokenizer
)
raise
ValueError
(
f
"Unknown guided decoding backend '
{
guided_params
.
backend
}
'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'"
)
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
)
def
get_local_guided_decoding_logits_processor
(
...
...
@@ -155,7 +166,13 @@ def get_local_guided_decoding_logits_processor(
get_local_xgrammar_guided_decoding_logits_processor
)
return
get_local_xgrammar_guided_decoding_logits_processor
(
guided_params
,
tokenizer
,
model_config
,
reasoner
)
if
guided_params
.
backend_name
==
'guidance'
:
from
vllm.model_executor.guided_decoding.guidance_decoding
import
(
get_local_guidance_guided_decoding_logits_processor
)
return
get_local_guidance_guided_decoding_logits_processor
(
guided_params
,
tokenizer
)
raise
ValueError
(
f
"Unknown guided decoding backend '
{
guided_params
.
backend
}
'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'"
)
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
)
vllm/model_executor/guided_decoding/guidance_decoding.py
0 → 100644
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
from
re
import
escape
as
regex_escape
import
llguidance
from
transformers
import
PreTrainedTokenizerBase
from
vllm.model_executor.guided_decoding.guidance_logits_processors
import
(
GuidanceLogitsProcessor
)
from
vllm.sampling_params
import
GuidedDecodingParams
def get_local_guidance_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor:
    """Build the guidance logits processor for a guided-decoding request.

    The first truthy field of ``guided_params`` decides the grammar, checked
    in this order: ``json``, ``json_object``, ``regex``, ``choice``,
    ``grammar``.

    Args:
        guided_params: Guided decoding parameters of the request.
        tokenizer: Tokenizer handed to the logits processor.

    Returns:
        A ``GuidanceLogitsProcessor`` built from the compiled grammar.

    Raises:
        ValueError: If no grammar could be derived from ``guided_params``.
    """
    # JSON-schema based modes share the same llguidance entry point; a bare
    # ``json_object`` request is expressed as the generic "object" schema.
    if guided_params.json:
        schema = guided_params.json
    elif guided_params.json_object:
        schema = '{"type": "object"}'
    else:
        schema = None

    if schema is not None:
        grammar_spec = llguidance.LLMatcher.grammar_from_json_schema(
            schema,
            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
    elif guided_params.regex:
        grammar_spec = llguidance.grammar_from("regex", guided_params.regex)
    elif guided_params.choice:
        # A choice is encoded as a regex alternation of the escaped options.
        alternatives = "|".join(
            regex_escape(str(option)) for option in guided_params.choice)
        grammar_spec = llguidance.grammar_from("regex", f"({alternatives})")
    elif guided_params.grammar:
        # this supports Lark and GBNF
        grammar_spec = llguidance.grammar_from("grammar",
                                               guided_params.grammar)
    else:
        grammar_spec = ""

    if not grammar_spec:
        raise ValueError("Unknown guided decoding mode")

    return GuidanceLogitsProcessor(grammar_spec, tokenizer)
vllm/model_executor/guided_decoding/guidance_logits_processors.py
0 → 100644
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Any
,
List
import
llguidance
import
llguidance.hf
import
llguidance.torch
import
torch
from
transformers
import
PreTrainedTokenizerBase
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class GuidanceLogitsProcessor:
    """Logits processor that masks scores with an llguidance matcher."""

    # Class-level cache shared by all instances: llguidance tokenizer
    # wrappers keyed by the tokenizer's ``name_or_path``.
    cached_tokenizers: dict[str, Any] = {}

    def __init__(
        self,
        grammar: str,
        tokenizer: PreTrainedTokenizerBase,
    ) -> None:
        """Store the grammar and tokenizer; matcher setup is deferred.

        Args:
            grammar (str)
                grammar to guide the generation
            tokenizer (PreTrainedTokenizerBase)
                model's tokenizer
        """
        self.grammar = grammar
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer.name_or_path
        # Becomes True after the first call; from then on the last input
        # token is fed to the matcher before each new mask is computed.
        self.new_sampling = False
        self.initialized = False

    def _initialize(self):
        """Create the llguidance tokenizer, matcher and bitmask once."""
        if not self.initialized:
            cache_key = self.tokenizer.name_or_path
            wrapped = self.cached_tokenizers.get(cache_key)
            if wrapped is None:
                wrapped = llguidance.hf.from_tokenizer(self.tokenizer, None)
                self.cached_tokenizers[cache_key] = wrapped
            self.ll_tokenizer = wrapped

            # Matcher log level comes from the environment, defaulting to 1.
            verbosity = int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1"))
            self.ll_matcher = llguidance.LLMatcher(
                self.ll_tokenizer,
                self.grammar,
                log_level=verbosity,
            )

            # create reusable bitmask
            self.bitmask = llguidance.torch.allocate_token_bitmask(
                1, self.ll_tokenizer.vocab_size)

            self.initialized = True

    def __call__(
        self,
        input_ids: List[int],
        scores: torch.Tensor,
    ) -> torch.Tensor:
        """Mask ``scores`` in place with the matcher's bitmask.

        Args:
            input_ids: Token ids seen so far; on every call after the first,
                the last one is consumed by the matcher.
            scores: Scores to mask; modified in place.

        Returns:
            The same ``scores`` tensor after masking.
        """
        # we initialize the guidance model here
        # to avoid pickling the llguidance tokenizer and matcher
        self._initialize()

        has_previous_token = self.new_sampling and len(input_ids) > 0
        if has_previous_token:
            self.ll_matcher.consume_token(input_ids[-1])
            matcher_error = self.ll_matcher.get_error()
            if matcher_error:
                logger.warning("Error in LLMatcher: %s", matcher_error)

        mask = self.bitmask
        llguidance.torch.fill_next_token_bitmask(self.ll_matcher, mask, 0)
        llguidance.torch.apply_token_bitmask_inplace(
            scores, mask.to(scores.device))

        self.new_sampling = True

        return scores
Prev
1
…
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment