Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0640f227
Commit
0640f227
authored
Sep 09, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.0' into v0.6.0-dev
parents
82f1ffdf
32e7db25
Changes
335
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
966 additions
and
44 deletions
+966
-44
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+3
-1
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+9
-4
vllm/entrypoints/openai/tool_parsers/__init__.py
vllm/entrypoints/openai/tool_parsers/__init__.py
+5
-0
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
+58
-0
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+344
-0
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+293
-0
vllm/entrypoints/openai/tool_parsers/utils.py
vllm/entrypoints/openai/tool_parsers/utils.py
+87
-0
vllm/envs.py
vllm/envs.py
+10
-1
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+2
-1
vllm/executor/distributed_gpu_executor.py
vllm/executor/distributed_gpu_executor.py
+6
-4
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+2
-1
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+3
-2
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+24
-17
vllm/executor/multiproc_xpu_executor.py
vllm/executor/multiproc_xpu_executor.py
+26
-0
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+14
-9
vllm/executor/openvino_executor.py
vllm/executor/openvino_executor.py
+2
-1
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+2
-1
vllm/executor/ray_tpu_executor.py
vllm/executor/ray_tpu_executor.py
+45
-1
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+29
-0
vllm/executor/tpu_executor.py
vllm/executor/tpu_executor.py
+2
-1
No files found.
vllm/entrypoints/openai/serving_embedding.py
View file @
0640f227
...
@@ -31,7 +31,9 @@ def _get_embedding(
...
@@ -31,7 +31,9 @@ def _get_embedding(
if
encoding_format
==
"float"
:
if
encoding_format
==
"float"
:
return
output
.
embedding
return
output
.
embedding
elif
encoding_format
==
"base64"
:
elif
encoding_format
==
"base64"
:
embedding_bytes
=
np
.
array
(
output
.
embedding
).
tobytes
()
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
embedding_bytes
=
np
.
array
(
output
.
embedding
,
dtype
=
"float32"
).
tobytes
()
return
base64
.
b64encode
(
embedding_bytes
).
decode
(
"utf-8"
)
return
base64
.
b64encode
(
embedding_bytes
).
decode
(
"utf-8"
)
assert_never
(
encoding_format
)
assert_never
(
encoding_format
)
...
...
vllm/entrypoints/openai/serving_tokenization.py
View file @
0640f227
...
@@ -4,7 +4,7 @@ from vllm.config import ModelConfig
...
@@ -4,7 +4,7 @@ from vllm.config import ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
apply_chat_template
,
from
vllm.entrypoints.chat_utils
import
(
apply_chat_template
,
load_chat_template
,
load_chat_template
,
parse_chat_messages
)
parse_chat_messages
_futures
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -43,7 +43,11 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -43,7 +43,11 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger
=
request_logger
)
request_logger
=
request_logger
)
# If this is None we use the tokenizer's default chat template
# If this is None we use the tokenizer's default chat template
self
.
chat_template
=
load_chat_template
(
chat_template
)
# the list of commonly-used chat template names for HF named templates
hf_chat_templates
:
List
[
str
]
=
[
'default'
,
'tool_use'
]
self
.
chat_template
=
chat_template
\
if
chat_template
in
hf_chat_templates
\
else
load_chat_template
(
chat_template
)
async
def
create_tokenize
(
async
def
create_tokenize
(
self
,
self
,
...
@@ -65,10 +69,11 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -65,10 +69,11 @@ class OpenAIServingTokenization(OpenAIServing):
if
isinstance
(
request
,
TokenizeChatRequest
):
if
isinstance
(
request
,
TokenizeChatRequest
):
model_config
=
self
.
model_config
model_config
=
self
.
model_config
conversation
,
mm_future
s
=
parse_chat_messages
(
conversation
,
mm_
data_
future
=
parse_chat_messages
_futures
(
request
.
messages
,
model_config
,
tokenizer
)
request
.
messages
,
model_config
,
tokenizer
)
if
mm_futures
:
mm_data
=
await
mm_data_future
if
mm_data
:
logger
.
warning
(
logger
.
warning
(
"Multi-modal inputs are ignored during tokenization"
)
"Multi-modal inputs are ignored during tokenization"
)
...
...
vllm/entrypoints/openai/tool_parsers/__init__.py
0 → 100644
View file @
0640f227
from
.abstract_tool_parser
import
ToolParser
from
.hermes_tool_parser
import
Hermes2ProToolParser
from
.mistral_tool_parser
import
MistralToolParser
__all__
=
[
"ToolParser"
,
"Hermes2ProToolParser"
,
"MistralToolParser"
]
\ No newline at end of file
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
0 → 100644
View file @
0640f227
from
typing
import
Dict
,
List
,
Sequence
,
Union
from
vllm.entrypoints.openai.protocol
import
(
DeltaMessage
,
ExtractedToolCallInformation
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
class
ToolParser
:
"""
Abstract ToolParser class that should not be used directly. Provided
properties and methods should be used in
derived classes.
"""
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
self
.
prev_tool_call_arr
:
List
[
Dict
]
=
[]
# the index of the tool call that is currently being parsed
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_initial_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
List
[
str
]
=
[]
self
.
model_tokenizer
=
tokenizer
def
extract_tool_calls
(
self
,
model_output
:
str
)
->
ExtractedToolCallInformation
:
"""
Static method that should be implemented for extracting tool calls from
a complete model-generated string.
Used for non-streaming responses where we have the entire model response
available before sending to the client.
Static because it's stateless.
"""
raise
NotImplementedError
(
"AbstractToolParser.extract_tool_calls has not been implemented!"
)
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
)
->
Union
[
DeltaMessage
,
None
]:
"""
Instance method that should be implemented for extracting tool calls
from an incomplete response; for use when handling tool calls and
streaming. Has to be an instance method because it requires state -
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise
NotImplementedError
(
"AbstractToolParser.extract_tool_calls_streaming has not been "
"implemented!"
)
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
0 → 100644
View file @
0640f227
import
json
import
re
from
typing
import
Dict
,
List
,
Sequence
,
Union
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
vllm.entrypoints.openai.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
InitialDeltaToolCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
)
from
vllm.entrypoints.openai.tool_parsers.utils
import
(
extract_intermediate_diff
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
logger
=
init_logger
(
__name__
)
class
Hermes2ProToolParser
(
ToolParser
):
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
super
().
__init__
(
tokenizer
)
if
isinstance
(
self
.
model_tokenizer
,
MistralTokenizer
):
logger
.
error
(
"Detected Mistral tokenizer when using a Hermes model"
)
self
.
model_tokenizer
=
self
.
model_tokenizer
.
tokenizer
self
.
current_tool_name_sent
:
bool
=
False
self
.
prev_tool_call_arr
:
List
[
Dict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
=
False
self
.
current_tool_initial_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
List
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
self
.
tool_call_start_token
:
str
=
"<tool_call>"
self
.
tool_call_end_token
:
str
=
"</tool_call>"
self
.
tool_call_regex
=
re
.
compile
(
r
"<tool_call>(.*?)</tool_call>|<tool_call>(.*)"
,
re
.
DOTALL
)
self
.
scratch_pad_regex
=
re
.
compile
(
r
"<scratch_pad>(.*?)</scratch_pad>"
,
re
.
DOTALL
)
if
not
self
.
model_tokenizer
:
raise
ValueError
(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self
.
tool_call_start_token_id
:
int
=
self
.
model_tokenizer
.
vocab
[
self
.
tool_call_start_token
]
self
.
tool_call_end_token_id
:
int
=
self
.
model_tokenizer
.
vocab
[
self
.
tool_call_end_token
]
if
not
self
.
tool_call_start_token_id
or
not
self
.
tool_call_end_token_id
:
raise
RuntimeError
(
"Hermes 2 Pro Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
def
extract_tool_calls
(
self
,
model_output
:
str
)
->
ExtractedToolCallInformation
:
# sanity check; avoid unnecessary processing
if
self
.
tool_call_start_token
not
in
model_output
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
else
:
try
:
# there are two possible captures - between tags, or between a
# tag and end-of-string so the result of
# findall is an array of tuples where one is a function call and
# the other is None
function_call_tuples
=
(
self
.
tool_call_regex
.
findall
(
model_output
))
# load the JSON, and then use it to build the Function and
# Tool Call
raw_function_calls
=
[
json
.
loads
(
match
[
0
]
if
match
[
0
]
else
match
[
1
])
for
match
in
function_call_tuples
]
tool_calls
=
[
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
function_call
[
"name"
],
# function call args are JSON but as a string
arguments
=
json
.
dumps
(
function_call
[
"arguments"
])))
for
function_call
in
raw_function_calls
]
content
=
model_output
[:
model_output
.
find
(
self
.
tool_call_start_token
)]
return
ExtractedToolCallInformation
(
tools_called
=
True
,
tool_calls
=
tool_calls
,
content
=
content
if
content
else
None
)
except
Exception
as
e
:
logger
.
error
(
"Error in extracting tool call from response %s"
,
e
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
)
->
Union
[
DeltaMessage
,
None
]:
logger
.
debug
(
"delta_text: %s"
,
delta_text
)
logger
.
debug
(
"delta_token_ids: %s"
,
delta_token_ids
)
# check to see if we should be streaming a tool call - is there a
if
self
.
tool_call_start_token_id
not
in
current_token_ids
:
logger
.
debug
(
"No tool call tokens found!"
)
return
DeltaMessage
(
content
=
delta_text
)
try
:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count
=
previous_token_ids
.
count
(
self
.
tool_call_start_token_id
)
prev_tool_end_count
=
previous_token_ids
.
count
(
self
.
tool_call_end_token_id
)
cur_tool_start_count
=
current_token_ids
.
count
(
self
.
tool_call_start_token_id
)
cur_tool_end_count
=
current_token_ids
.
count
(
self
.
tool_call_end_token_id
)
# case: if we're generating text, OR rounding out a tool call
if
(
cur_tool_start_count
==
cur_tool_end_count
and
prev_tool_end_count
==
cur_tool_end_count
):
logger
.
debug
(
"Generating text content! skipping tool parsing."
)
if
delta_text
!=
self
.
tool_call_end_token
:
return
DeltaMessage
(
content
=
delta_text
)
# case: if tool open & close tag counts don't match, we're doing
# imaginary "else" block here
# something with tools with this diff.
# flags for partial JSON parting. exported constants from
# "Allow" are handled via BIT MASK
flags
=
Allow
.
ALL
if
self
.
current_tool_name_sent
\
else
Allow
.
ALL
&
~
Allow
.
STR
# case -- we're starting a new tool call
if
(
cur_tool_start_count
>
cur_tool_end_count
and
cur_tool_start_count
>
prev_tool_start_count
):
if
len
(
delta_token_ids
)
>
1
:
tool_call_portion
=
current_text
.
split
(
self
.
tool_call_start_token
)[
-
1
]
else
:
tool_call_portion
=
None
delta
=
None
text_portion
=
None
# set cursors and state appropriately
self
.
current_tool_id
+=
1
self
.
current_tool_name_sent
=
False
self
.
current_tool_initial_sent
=
False
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"Starting on a new tool %s"
,
self
.
current_tool_id
)
# case -- we're updating an existing tool call
elif
(
cur_tool_start_count
>
cur_tool_end_count
and
cur_tool_start_count
==
prev_tool_start_count
):
# get the portion of the text that's the tool call
tool_call_portion
=
current_text
.
split
(
self
.
tool_call_start_token
)[
-
1
]
text_portion
=
None
# case -- the current tool call is being closed.
elif
(
cur_tool_start_count
==
cur_tool_end_count
and
cur_tool_end_count
>
prev_tool_end_count
):
diff
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
if
diff
:
diff
=
json
.
dumps
(
diff
).
replace
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
],
""
)
logger
.
debug
(
"Finishing tool and found diff that had not "
"been streamed yet: %s"
,
diff
)
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
\
+=
diff
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
diff
).
model_dump
(
exclude_none
=
True
))
])
# case -- otherwise we're just generating text
else
:
text
=
delta_text
.
replace
(
self
.
tool_call_start_token
,
""
)
text
=
text
.
replace
(
self
.
tool_call_end_token
,
""
)
delta
=
DeltaMessage
(
tool_calls
=
[],
content
=
text
)
return
delta
try
:
current_tool_call
=
partial_json_parser
.
loads
(
tool_call_portion
or
"{}"
,
flags
)
if
tool_call_portion
else
None
logger
.
debug
(
"Parsed tool call %s"
,
current_tool_call
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
return
None
# case - we haven't sent the initial delta with the tool call ID
# (it will be sent)
if
not
self
.
current_tool_initial_sent
:
self
.
current_tool_initial_sent
=
True
return
DeltaMessage
(
tool_calls
=
[
InitialDeltaToolCall
(
index
=
self
.
current_tool_id
).
model_dump
(
exclude_none
=
True
)
])
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
elif
not
self
.
current_tool_name_sent
:
function_name
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"name"
)
if
function_name
:
self
.
current_tool_name_sent
=
True
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
name
=
function_name
).
model_dump
(
exclude_none
=
True
))
])
else
:
return
None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if
tool_call_portion
is
None
:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta
=
DeltaMessage
(
content
=
delta_text
)
\
if
text_portion
is
not
None
else
None
return
delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger
.
debug
(
"Trying to parse current tool call with ID %s"
,
self
.
current_tool_id
)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if
len
(
self
.
prev_tool_call_arr
)
<=
self
.
current_tool_id
:
self
.
prev_tool_call_arr
.
append
({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments
=
(
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
))
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
logger
.
debug
(
"diffing old arguments: %s"
,
prev_arguments
)
logger
.
debug
(
"against new ones: %s"
,
cur_arguments
)
# case -- no arguments have been created yet. skip sending a delta.
if
not
cur_arguments
and
not
prev_arguments
:
logger
.
debug
(
"Skipping text %s - no arguments"
,
delta_text
)
delta
=
None
# case -- prev arguments are defined, but non are now.
# probably impossible, but not a fatal error - just keep going
elif
not
cur_arguments
and
prev_arguments
:
logger
.
error
(
"should be impossible to have arguments reset "
"mid-call. skipping streaming anything."
)
delta
=
None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif
cur_arguments
and
not
prev_arguments
:
cur_arguments_json
=
json
.
dumps
(
cur_arguments
)
logger
.
debug
(
"finding %s in %s"
,
delta_text
,
cur_arguments_json
)
# get the location where previous args differ from current
args_delta_start_loc
=
cur_arguments_json
.
index
(
delta_text
)
\
+
len
(
delta_text
)
# use that to find the actual delta
arguments_delta
=
cur_arguments_json
[:
args_delta_start_loc
]
logger
.
debug
(
"First tokens in arguments received: %s"
,
arguments_delta
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
arguments_delta
).
model_dump
(
exclude_none
=
True
))
])
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
\
+=
arguments_delta
# last case -- we have an update to existing arguments.
elif
cur_arguments
and
prev_arguments
:
cur_args_json
=
json
.
dumps
(
cur_arguments
)
prev_args_json
=
json
.
dumps
(
prev_arguments
)
logger
.
debug
(
"Searching for diff between
\n
%s"
,
cur_args_json
)
logger
.
debug
(
"and
\n
%s"
,
prev_args_json
)
argument_diff
=
extract_intermediate_diff
(
cur_args_json
,
prev_args_json
)
logger
.
debug
(
"got argument diff %s"
,
argument_diff
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
argument_diff
).
model_dump
(
exclude_none
=
True
))
])
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
\
+=
argument_diff
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if
self
.
current_tool_id
==
len
(
self
.
prev_tool_call_arr
)
-
1
:
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
\
current_tool_call
else
:
self
.
prev_tool_call_arr
.
append
(
current_tool_call
)
return
delta
except
Exception
as
e
:
logger
.
error
(
"Error trying to handle streaming tool call: %s"
,
e
)
return
None
# do not stream a delta. skip this token ID.
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
0 → 100644
View file @
0640f227
import
json
import
re
from
typing
import
Dict
,
List
,
Sequence
,
Union
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
vllm.entrypoints.openai.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
InitialDeltaToolCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
)
from
vllm.entrypoints.openai.tool_parsers.utils
import
(
extract_intermediate_diff
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
logger
=
init_logger
(
__name__
)
class
MistralToolParser
(
ToolParser
):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser gmistral are all set
"""
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
super
().
__init__
(
tokenizer
)
if
isinstance
(
self
.
model_tokenizer
,
MistralTokenizer
):
self
.
model_tokenizer
=
self
.
model_tokenizer
.
tokenizer
else
:
logger
.
info
(
"Non-Mistral tokenizer detected when using a Mistral "
"model..."
)
# initialize properties used for state when parsing tool calls in
# streaming mode
self
.
prev_tool_call_arr
:
List
[
Dict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_initial_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
List
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
self
.
bot_token
=
"[TOOL_CALLS]"
self
.
bot_token_id
=
self
.
model_tokenizer
.
vocab
[
self
.
bot_token
]
self
.
tool_call_regex
=
re
.
compile
(
r
"\[{.*?}\]"
,
re
.
DOTALL
)
def
extract_tool_calls
(
self
,
model_output
:
str
)
->
ExtractedToolCallInformation
:
"""
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing,
make sure your tool call arguments don't ever include quotes!
"""
# case -- if a tool call token is not present, return a text response
if
self
.
bot_token
not
in
model_output
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
try
:
# use a regex to find the tool call. remove the BOT token
# and make sure to replace single quotes with double quotes
raw_tool_call
=
self
.
tool_call_regex
.
findall
(
model_output
.
replace
(
self
.
bot_token
,
""
))[
0
]
# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr
=
json
.
loads
(
raw_tool_call
)
tool_calls
:
List
[
ToolCall
]
=
[
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
raw_function_call
[
"name"
],
# function call args are JSON but as a string
arguments
=
json
.
dumps
(
raw_function_call
[
"arguments"
])))
for
raw_function_call
in
function_call_arr
]
# get any content before the tool call
content
=
model_output
.
split
(
self
.
bot_token
)[
0
]
return
ExtractedToolCallInformation
(
tools_called
=
True
,
tool_calls
=
tool_calls
,
content
=
content
if
len
(
content
)
>
0
else
None
)
except
Exception
as
e
:
logger
.
error
(
"Error in extracting tool call from response: %s"
,
e
)
print
(
"ERROR"
,
e
)
# return information to just treat the tool call as regular JSON
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
)
->
Union
[
DeltaMessage
,
None
]:
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if
self
.
bot_token_id
not
in
current_token_ids
:
return
DeltaMessage
(
content
=
delta_text
)
# if the tool call token ID IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle if we detected the BOT token which means the start of tool
# calling
if
(
self
.
bot_token_id
in
delta_token_ids
and
len
(
delta_token_ids
)
==
1
):
# if it's the only token, return None, so we don't send a chat
# completion any don't send a control token
return
None
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags
=
Allow
.
ALL
if
self
.
current_tool_name_sent
\
else
Allow
.
ALL
&
~
Allow
.
STR
try
:
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr
=
current_text
.
split
(
self
.
bot_token
)[
1
]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try
:
tool_call_arr
:
List
[
Dict
]
=
partial_json_parser
.
loads
(
parsable_arr
,
flags
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
return
None
# select as the current tool call the one we're on the state at
current_tool_call
:
Dict
=
tool_call_arr
[
self
.
current_tool_id
]
\
if
len
(
tool_call_arr
)
>
0
else
{}
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if
len
(
tool_call_arr
)
==
0
:
return
None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif
(
len
(
tool_call_arr
)
>
0
and
len
(
tool_call_arr
)
>
self
.
current_tool_id
+
1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if
self
.
current_tool_id
>=
0
:
diff
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"arguments"
)
if
diff
:
diff
=
json
.
dumps
(
diff
).
replace
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
],
""
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
diff
).
model_dump
(
exclude_none
=
True
))
])
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
diff
else
:
delta
=
None
else
:
delta
=
None
# re-set stuff pertaining to progress in the current tool
self
.
current_tool_id
=
len
(
tool_call_arr
)
-
1
self
.
current_tool_name_sent
=
False
self
.
current_tool_initial_sent
=
False
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"starting on new tool %d"
,
self
.
current_tool_id
)
return
delta
# case: update an existing tool - this is handled below
# if the current tool initial data incl. the id, type=function
# and idx not sent, send that
if
not
self
.
current_tool_initial_sent
:
self
.
current_tool_initial_sent
=
True
delta
=
DeltaMessage
(
tool_calls
=
[
InitialDeltaToolCall
(
index
=
self
.
current_tool_id
).
model_dump
(
exclude_none
=
True
)
])
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif
not
self
.
current_tool_name_sent
:
function_name
=
current_tool_call
.
get
(
"name"
)
if
function_name
:
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
name
=
function_name
).
model_dump
(
exclude_none
=
True
))
])
self
.
current_tool_name_sent
=
True
else
:
delta
=
None
# now we know we're on the same tool call and we're streaming
# arguments
else
:
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
new_text
=
delta_text
.
replace
(
"
\'
"
,
"
\"
"
)
if
not
cur_arguments
and
not
prev_arguments
:
delta
=
None
elif
not
cur_arguments
and
prev_arguments
:
logger
.
error
(
"INVARIANT - impossible to have arguments reset "
"mid-arguments"
)
delta
=
None
elif
cur_arguments
and
not
prev_arguments
:
cur_arguments_json
=
json
.
dumps
(
cur_arguments
)
logger
.
debug
(
"finding %s in %s"
,
new_text
,
cur_arguments_json
)
arguments_delta
=
cur_arguments_json
[:
cur_arguments_json
.
index
(
new_text
)
+
len
(
new_text
)]
logger
.
debug
(
"First tokens in arguments received: %s"
,
arguments_delta
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
arguments_delta
).
model_dump
(
exclude_none
=
True
))
])
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
arguments_delta
elif
cur_arguments
and
prev_arguments
:
cur_args_json
=
json
.
dumps
(
cur_arguments
)
prev_args_json
=
json
.
dumps
(
prev_arguments
)
logger
.
debug
(
"Searching for diff between
\n
%s
\n
%s"
,
cur_args_json
,
prev_args_json
)
argument_diff
=
extract_intermediate_diff
(
cur_args_json
,
prev_args_json
)
logger
.
debug
(
"got arguments diff: %s"
,
argument_diff
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
argument_diff
).
model_dump
(
exclude_none
=
True
))
])
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
argument_diff
else
:
# try parsing it with regular JSON - if it works we're
# at the end, and we need to send the difference between
# tokens streamed so far and the valid JSON
delta
=
None
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
self
.
prev_tool_call_arr
=
tool_call_arr
return
delta
except
Exception
as
e
:
logger
.
error
(
"Error trying to handle streaming tool call: %s"
,
e
)
logger
.
debug
(
"Skipping chunk as a result of tool streaming extraction "
"error"
)
return
None
vllm/entrypoints/openai/tool_parsers/utils.py
0 → 100644
View file @
0640f227
def
find_common_prefix
(
s1
:
str
,
s2
:
str
)
->
str
:
"""
Finds a common prefix that is shared between two strings, if there is one.
Order of arguments is NOT important.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely.
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
'{"fruit": "ap'
"""
prefix
=
''
min_length
=
min
(
len
(
s1
),
len
(
s2
))
for
i
in
range
(
0
,
min_length
):
if
s1
[
i
]
==
s2
[
i
]:
prefix
+=
s1
[
i
]
else
:
break
return
prefix
def
find_common_suffix
(
s1
:
str
,
s2
:
str
)
->
str
:
"""
Finds a common suffix shared between two strings, if there is one. Order of
arguments is NOT important.
Stops when the suffix ends OR it hits an alphanumeric character
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
"""
suffix
=
''
min_length
=
min
(
len
(
s1
),
len
(
s2
))
for
i
in
range
(
1
,
min_length
+
1
):
if
s1
[
-
i
]
==
s2
[
-
i
]
and
not
s1
[
-
i
].
isalnum
():
suffix
=
s1
[
-
i
]
+
suffix
else
:
break
return
suffix
def
extract_intermediate_diff
(
curr
:
str
,
old
:
str
)
->
str
:
"""
Given two strings, extract the difference in the middle between two strings
that are known to have a common prefix and/or suffix.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely. The order of arguments IS
important - the new version of the partially-parsed JSON must be the first
argument, and the secnod argument must be from the previous generation.
What it returns, is tokens that should be streamed to the client.
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
-> 'ple'
"""
suffix
=
find_common_suffix
(
curr
,
old
)
old
=
old
[::
-
1
].
replace
(
suffix
[::
-
1
],
''
,
1
)[::
-
1
]
prefix
=
find_common_prefix
(
curr
,
old
)
diff
=
curr
if
len
(
suffix
):
diff
=
diff
[::
-
1
].
replace
(
suffix
[::
-
1
],
''
,
1
)[::
-
1
]
if
len
(
prefix
):
# replace the prefix only once in case it's mirrored
diff
=
diff
.
replace
(
prefix
,
''
,
1
)
return
diff
def
find_all_indices
(
string
,
substring
):
"""
Find all (starting) indices of a substring in a given string. Useful for
tool call extraction
"""
indices
=
[]
index
=
-
1
while
True
:
index
=
string
.
find
(
substring
,
index
+
1
)
if
index
==
-
1
:
break
indices
.
append
(
index
)
return
indices
vllm/envs.py
View file @
0640f227
...
@@ -35,6 +35,7 @@ if TYPE_CHECKING:
...
@@ -35,6 +35,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
bool
=
False
VLLM_USE_FLASHINFER_SAMPLER
:
bool
=
False
VLLM_USE_FLASHINFER_REJECTION_SAMPLER
:
bool
=
False
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
...
@@ -220,6 +221,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -220,6 +221,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Internal flag to enable Dynamo graph capture
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
:
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
,
"0"
)),
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
,
"0"
)),
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
# the GPU device id
# the GPU device id
...
@@ -372,7 +377,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -372,7 +377,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
os
.
path
.
join
(
get_default_cache_root
(),
"vllm"
,
"xla_cache"
),
os
.
path
.
join
(
get_default_cache_root
(),
"vllm"
,
"xla_cache"
),
)),
)),
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"
65536
"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"
32768
"
)),
# If set, vllm will skip the deprecation warnings.
# If set, vllm will skip the deprecation warnings.
"VLLM_NO_DEPRECATION_WARNING"
:
"VLLM_NO_DEPRECATION_WARNING"
:
...
@@ -424,6 +429,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -424,6 +429,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_TORCH_PROFILER_DIR"
:
"VLLM_TORCH_PROFILER_DIR"
:
lambda
:
(
None
if
os
.
getenv
(
"VLLM_TORCH_PROFILER_DIR"
,
None
)
is
None
else
os
lambda
:
(
None
if
os
.
getenv
(
"VLLM_TORCH_PROFILER_DIR"
,
None
)
is
None
else
os
.
path
.
expanduser
(
os
.
getenv
(
"VLLM_TORCH_PROFILER_DIR"
,
"."
))),
.
path
.
expanduser
(
os
.
getenv
(
"VLLM_TORCH_PROFILER_DIR"
,
"."
))),
# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_TRITON_AWQ"
,
"0"
))),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/executor/cpu_executor.py
View file @
0640f227
...
@@ -11,8 +11,9 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
...
@@ -11,8 +11,9 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler
,
WorkerMonitor
)
ResultHandler
,
WorkerMonitor
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
GiB_bytes
,
get_distributed_init_method
,
get_open_port
,
from
vllm.utils
import
(
GiB_bytes
,
get_distributed_init_method
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
get_vllm_instance_id
,
make_async
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
...
vllm/executor/distributed_gpu_executor.py
View file @
0640f227
...
@@ -6,7 +6,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase
...
@@ -6,7 +6,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -64,8 +65,9 @@ class DistributedGPUExecutor(GPUExecutor):
...
@@ -64,8 +65,9 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks
=
num_cpu_blocks
)
num_cpu_blocks
=
num_cpu_blocks
)
def
execute_model
(
def
execute_model
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
if
self
.
parallel_worker_tasks
is
None
:
if
self
.
parallel_worker_tasks
is
None
:
self
.
parallel_worker_tasks
=
self
.
_run_workers
(
self
.
parallel_worker_tasks
=
self
.
_run_workers
(
"start_worker_execution_loop"
,
"start_worker_execution_loop"
,
...
@@ -188,7 +190,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
...
@@ -188,7 +190,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
@
abstractmethod
@
abstractmethod
async
def
_driver_execute_model_async
(
async
def
_driver_execute_model_async
(
self
,
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
"""Execute the model asynchronously in the driver worker.
"""Execute the model asynchronously in the driver worker.
...
...
vllm/executor/executor_base.py
View file @
0640f227
...
@@ -6,8 +6,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -6,8 +6,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig
,
SchedulerConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
SpeculativeConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
class
ExecutorBase
(
ABC
):
class
ExecutorBase
(
ABC
):
...
...
vllm/executor/gpu_executor.py
View file @
0640f227
...
@@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
...
@@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
make_async
)
from
vllm.worker.worker_base
import
WorkerBase
,
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerBase
,
WorkerWrapperBase
...
@@ -176,5 +177,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
...
@@ -176,5 +177,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]]:
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
execute_model_req
=
execute_model_req
,
)
)(
execute_model_req
=
execute_model_req
)
return
output
return
output
vllm/executor/multiproc_gpu_executor.py
View file @
0640f227
...
@@ -14,7 +14,8 @@ from vllm.executor.gpu_executor import create_worker
...
@@ -14,7 +14,8 @@ from vllm.executor.gpu_executor import create_worker
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
ResultHandler
,
WorkerMonitor
)
ResultHandler
,
WorkerMonitor
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.triton_utils
import
maybe_set_triton_cache_manager
from
vllm.triton_utils
import
maybe_set_triton_cache_manager
from
vllm.utils
import
(
_run_task_with_lock
,
cuda_device_count_stateless
,
from
vllm.utils
import
(
_run_task_with_lock
,
cuda_device_count_stateless
,
get_distributed_init_method
,
get_open_port
,
get_distributed_init_method
,
get_open_port
,
...
@@ -30,16 +31,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
...
@@ -30,16 +31,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
uses_ray
:
bool
=
False
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
self
.
_check_executor_parameters
()
# Create the parallel GPU workers.
# Create the parallel GPU workers.
world_size
=
self
.
parallel_config
.
world_size
world_size
=
self
.
parallel_config
.
world_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if
"CUDA_VISIBLE_DEVICES"
not
in
os
.
environ
:
update_environment_variables
({
"CUDA_VISIBLE_DEVICES"
:
(
","
.
join
(
map
(
str
,
range
(
world_size
))))
})
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os
.
environ
[
"VLLM_INSTANCE_ID"
]
=
get_vllm_instance_id
()
os
.
environ
[
"VLLM_INSTANCE_ID"
]
=
get_vllm_instance_id
()
...
@@ -68,16 +65,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
...
@@ -68,16 +65,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
if
world_size
>
1
:
if
world_size
>
1
:
maybe_set_triton_cache_manager
()
maybe_set_triton_cache_manager
()
cuda_device_count
=
cuda_device_count_stateless
()
# Use confusing message for more common TP-only case.
assert
tensor_parallel_size
<=
cuda_device_count
,
(
f
"please set tensor_parallel_size (
{
tensor_parallel_size
}
) "
f
"to less than max local gpu count (
{
cuda_device_count
}
)"
)
assert
world_size
<=
cuda_device_count
,
(
f
"please ensure that world_size (
{
world_size
}
) "
f
"is less than than max local gpu count (
{
cuda_device_count
}
)"
)
# Multiprocessing-based executor does not support multi-node setting.
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
# 127.0.0.1 for communication.
...
@@ -139,6 +126,26 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
...
@@ -139,6 +126,26 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers
=
self
.
parallel_config
.
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
)
max_parallel_loading_workers
)
def
_check_executor_parameters
(
self
):
world_size
=
self
.
parallel_config
.
world_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if
"CUDA_VISIBLE_DEVICES"
not
in
os
.
environ
:
update_environment_variables
({
"CUDA_VISIBLE_DEVICES"
:
(
","
.
join
(
map
(
str
,
range
(
world_size
))))
})
cuda_device_count
=
cuda_device_count_stateless
()
# Use confusing message for more common TP-only case.
assert
tensor_parallel_size
<=
cuda_device_count
,
(
f
"please set tensor_parallel_size (
{
tensor_parallel_size
}
) "
f
"to less than max local gpu count (
{
cuda_device_count
}
)"
)
assert
world_size
<=
cuda_device_count
,
(
f
"please ensure that world_size (
{
world_size
}
) "
f
"is less than than max local gpu count (
{
cuda_device_count
}
)"
)
def
shutdown
(
self
):
def
shutdown
(
self
):
if
(
worker_monitor
:
=
getattr
(
self
,
"worker_monitor"
,
if
(
worker_monitor
:
=
getattr
(
self
,
"worker_monitor"
,
None
))
is
not
None
:
None
))
is
not
None
:
...
...
vllm/executor/multiproc_xpu_executor.py
0 → 100644
View file @
0640f227
import
vllm.envs
as
envs
from
vllm.executor.multiproc_gpu_executor
import
(
MultiprocessingGPUExecutor
,
MultiprocessingGPUExecutorAsync
)
from
vllm.executor.xpu_executor
import
XPUExecutor
from
vllm.logger
import
init_logger
from
vllm.utils
import
make_async
logger
=
init_logger
(
__name__
)
class
MultiprocessingXPUExecutor
(
MultiprocessingGPUExecutor
,
XPUExecutor
):
"""Python multiprocessing-based multi-XPU executor"""
def
_check_executor_parameters
(
self
):
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
if
mp_method
!=
"spawn"
:
raise
RuntimeError
(
"XPU multiprocess executor only support spawn as mp method"
)
class
MultiprocessingXPUExecutorAsync
(
MultiprocessingXPUExecutor
,
MultiprocessingGPUExecutorAsync
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
driver_exec_model
=
make_async
(
self
.
driver_worker
.
execute_model
)
vllm/executor/neuron_executor.py
View file @
0640f227
...
@@ -3,8 +3,10 @@ from typing import List, Set, Tuple
...
@@ -3,8 +3,10 @@ from typing import List, Set, Tuple
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.utils
import
make_async
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -24,14 +26,17 @@ class NeuronExecutor(ExecutorBase):
...
@@ -24,14 +26,17 @@ class NeuronExecutor(ExecutorBase):
def
_init_worker
(
self
):
def
_init_worker
(
self
):
from
vllm.worker.neuron_worker
import
NeuronWorker
from
vllm.worker.neuron_worker
import
NeuronWorker
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
self
.
driver_worker
=
NeuronWorker
(
self
.
driver_worker
=
NeuronWorker
(
self
.
model_config
,
model_config
=
self
.
model_config
,
self
.
parallel_config
,
parallel_config
=
self
.
parallel_config
,
self
.
scheduler_config
,
scheduler_config
=
self
.
scheduler_config
,
self
.
device_config
,
device_config
=
self
.
device_config
,
self
.
cache_config
,
cache_config
=
self
.
cache_config
,
)
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
)
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
load_model
()
self
.
driver_worker
.
load_model
()
...
...
vllm/executor/openvino_executor.py
View file @
0640f227
...
@@ -9,7 +9,8 @@ from vllm.config import CacheConfig, ModelConfig
...
@@ -9,7 +9,8 @@ from vllm.config import CacheConfig, ModelConfig
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
GiB_bytes
,
get_distributed_init_method
,
get_ip
,
from
vllm.utils
import
(
GiB_bytes
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
get_open_port
,
make_async
)
...
...
vllm/executor/ray_gpu_executor.py
View file @
0640f227
...
@@ -12,7 +12,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
...
@@ -12,7 +12,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
from
vllm.executor.msgspec_utils
import
encode_hook
from
vllm.executor.msgspec_utils
import
encode_hook
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
_run_task_with_lock
,
get_distributed_init_method
,
from
vllm.utils
import
(
_run_task_with_lock
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
make_async
)
...
...
vllm/executor/ray_tpu_executor.py
View file @
0640f227
...
@@ -10,7 +10,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase
...
@@ -10,7 +10,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.executor.tpu_executor
import
TPUExecutor
from
vllm.executor.tpu_executor
import
TPUExecutor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
get_vllm_instance_id
,
make_async
)
...
@@ -70,6 +71,19 @@ class RayTPUExecutor(TPUExecutor):
...
@@ -70,6 +71,19 @@ class RayTPUExecutor(TPUExecutor):
worker_module_name
=
"vllm.worker.tpu_worker"
worker_module_name
=
"vllm.worker.tpu_worker"
worker_class_name
=
"TPUWorker"
worker_class_name
=
"TPUWorker"
# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# need to override the Ray environment variables manually.
override_env
=
{}
if
"TPU_CHIPS_PER_HOST_BOUNDS"
in
os
.
environ
:
override_env
.
update
({
"TPU_CHIPS_PER_HOST_BOUNDS"
:
os
.
environ
[
"TPU_CHIPS_PER_HOST_BOUNDS"
]
})
if
"TPU_HOST_BOUNDS"
in
os
.
environ
:
override_env
.
update
(
{
"TPU_HOST_BOUNDS"
:
os
.
environ
[
"TPU_HOST_BOUNDS"
]})
worker
=
ray
.
remote
(
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_cpus
=
0
,
resources
=
{
"TPU"
:
1
},
resources
=
{
"TPU"
:
1
},
...
@@ -80,6 +94,8 @@ class RayTPUExecutor(TPUExecutor):
...
@@ -80,6 +94,8 @@ class RayTPUExecutor(TPUExecutor):
worker_class_name
=
worker_class_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
)
if
override_env
:
worker
.
override_env_vars
.
remote
(
override_env
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
...
@@ -95,12 +111,40 @@ class RayTPUExecutor(TPUExecutor):
...
@@ -95,12 +111,40 @@ class RayTPUExecutor(TPUExecutor):
# Else, added to the list of workers.
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
self
.
workers
.
append
(
worker
)
logger
.
debug
(
"workers: %s"
,
self
.
workers
)
logger
.
debug
(
"driver_dummy_worker: %s"
,
self
.
driver_dummy_worker
)
if
self
.
driver_dummy_worker
is
None
:
if
self
.
driver_dummy_worker
is
None
:
raise
ValueError
(
raise
ValueError
(
"Ray does not allocate any TPUs on the driver node. Consider "
"Ray does not allocate any TPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"adjusting the Ray placement group or running the driver on a "
"TPU node."
)
"TPU node."
)
worker_ips
=
[
ray
.
get
(
worker
.
get_node_ip
.
remote
())
# type: ignore[attr-defined]
for
worker
in
self
.
workers
]
ip_counts
:
Dict
[
str
,
int
]
=
{}
for
ip
in
worker_ips
:
ip_counts
[
ip
]
=
ip_counts
.
get
(
ip
,
0
)
+
1
def
sort_by_driver_then_worker_ip
(
worker
):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the work is on a node with smaller IP address, it
should be placed first.
"""
ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
return
(
ip
!=
driver_ip
,
ip_counts
[
ip
],
ip
)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self
.
workers
=
sorted
(
self
.
workers
,
key
=
sort_by_driver_then_worker_ip
)
# Get the set of TPU IDs used on each node.
# Get the set of TPU IDs used on each node.
worker_node_and_gpu_ids
=
self
.
_run_workers
(
"get_node_and_gpu_ids"
,
worker_node_and_gpu_ids
=
self
.
_run_workers
(
"get_node_and_gpu_ids"
,
use_dummy_driver
=
True
)
use_dummy_driver
=
True
)
...
...
vllm/executor/ray_utils.py
View file @
0640f227
import
os
import
time
import
time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
...
@@ -84,6 +85,9 @@ try:
...
@@ -84,6 +85,9 @@ try:
return
output
return
output
def
override_env_vars
(
self
,
vars
:
Dict
[
str
,
str
]):
os
.
environ
.
update
(
vars
)
ray_import_err
=
None
ray_import_err
=
None
except
ImportError
as
e
:
except
ImportError
as
e
:
...
@@ -291,3 +295,28 @@ def initialize_ray_cluster(
...
@@ -291,3 +295,28 @@ def initialize_ray_cluster(
_verify_bundles
(
current_placement_group
,
parallel_config
,
device_str
)
_verify_bundles
(
current_placement_group
,
parallel_config
,
device_str
)
# Set the placement group in the parallel config
# Set the placement group in the parallel config
parallel_config
.
placement_group
=
current_placement_group
parallel_config
.
placement_group
=
current_placement_group
def
get_num_tpu_nodes
()
->
int
:
from
ray._private.accelerators
import
TPUAcceleratorManager
cluster_resources
=
ray
.
cluster_resources
()
total_tpus
=
int
(
cluster_resources
[
"TPU"
])
tpus_per_node
=
TPUAcceleratorManager
.
get_current_node_num_accelerators
()
assert
total_tpus
%
tpus_per_node
==
0
return
total_tpus
//
tpus_per_node
def
get_num_nodes_in_placement_group
()
->
int
:
pg_table
=
ray
.
util
.
placement_group_table
()
current_pg
=
ray
.
util
.
get_current_placement_group
()
num_nodes
=
0
if
current_pg
:
nodes_in_pg
=
set
()
for
pg_key
,
pg
in
pg_table
.
items
():
if
pg_key
==
current_pg
.
id
.
hex
():
for
_
,
node
in
pg
[
"bundles_to_node_id"
].
items
():
nodes_in_pg
.
add
(
node
)
num_nodes
=
len
(
nodes_in_pg
)
return
num_nodes
vllm/executor/tpu_executor.py
View file @
0640f227
...
@@ -5,7 +5,8 @@ import torch
...
@@ -5,7 +5,8 @@ import torch
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
make_async
)
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment