Commit 0640f227 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
...@@ -31,7 +31,9 @@ def _get_embedding( ...@@ -31,7 +31,9 @@ def _get_embedding(
if encoding_format == "float": if encoding_format == "float":
return output.embedding return output.embedding
elif encoding_format == "base64": elif encoding_format == "base64":
embedding_bytes = np.array(output.embedding).tobytes() # Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
return base64.b64encode(embedding_bytes).decode("utf-8") return base64.b64encode(embedding_bytes).decode("utf-8")
assert_never(encoding_format) assert_never(encoding_format)
......
...@@ -4,7 +4,7 @@ from vllm.config import ModelConfig ...@@ -4,7 +4,7 @@ from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (apply_chat_template, from vllm.entrypoints.chat_utils import (apply_chat_template,
load_chat_template, load_chat_template,
parse_chat_messages) parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -43,7 +43,11 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -43,7 +43,11 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger=request_logger) request_logger=request_logger)
# If this is None we use the tokenizer's default chat template # If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template) # the list of commonly-used chat template names for HF named templates
hf_chat_templates: List[str] = ['default', 'tool_use']
self.chat_template = chat_template \
if chat_template in hf_chat_templates \
else load_chat_template(chat_template)
async def create_tokenize( async def create_tokenize(
self, self,
...@@ -65,10 +69,11 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -65,10 +69,11 @@ class OpenAIServingTokenization(OpenAIServing):
if isinstance(request, TokenizeChatRequest): if isinstance(request, TokenizeChatRequest):
model_config = self.model_config model_config = self.model_config
conversation, mm_futures = parse_chat_messages( conversation, mm_data_future = parse_chat_messages_futures(
request.messages, model_config, tokenizer) request.messages, model_config, tokenizer)
if mm_futures: mm_data = await mm_data_future
if mm_data:
logger.warning( logger.warning(
"Multi-modal inputs are ignored during tokenization") "Multi-modal inputs are ignored during tokenization")
......
from .abstract_tool_parser import ToolParser
from .hermes_tool_parser import Hermes2ProToolParser
from .mistral_tool_parser import MistralToolParser
__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"]
\ No newline at end of file
from typing import Dict, List, Sequence, Union
from vllm.entrypoints.openai.protocol import (DeltaMessage,
ExtractedToolCallInformation)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
class ToolParser:
    """
    Base class for tool-call parsers. It is not meant to be used directly;
    derived classes reuse the state set up here and override both
    extraction methods.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        self.model_tokenizer = tokenizer

        # Per-instance bookkeeping initialised for derived parsers.
        self.prev_tool_call_arr: List[Dict] = []
        self.streamed_args_for_tool: List[str] = []
        # the index of the tool call that is currently being parsed
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.current_tool_initial_sent: bool = False

    def extract_tool_calls(self,
                           model_output: str) -> ExtractedToolCallInformation:
        """
        Extract tool calls from a complete model-generated string.

        Intended for non-streaming responses, where the entire model
        response is available before anything is sent to the client.
        Derived classes must override this; the base implementation always
        raises NotImplementedError.
        """
        message = (
            "AbstractToolParser.extract_tool_calls has not been implemented!")
        raise NotImplementedError(message)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract tool calls from an incomplete (streaming) response.

        This is an instance method because streaming needs state: the
        current text/token diffs passed in, plus what has previously been
        parsed and extracted (see the constructor). Derived classes must
        override this; the base implementation always raises
        NotImplementedError.
        """
        message = ("AbstractToolParser.extract_tool_calls_streaming has not "
                   "been implemented!")
        raise NotImplementedError(message)
import json
import re
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
InitialDeltaToolCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
logger = init_logger(__name__)
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
    """
    Set up streaming state, the tool-call tags/regexes, and resolve the
    token ids of the ``<tool_call>`` / ``</tool_call>`` tags.

    Raises:
        ValueError: if no (truthy) tokenizer is available.
        RuntimeError: if either tool-call tag is missing from the
            tokenizer vocabulary.
    """
    super().__init__(tokenizer)

    if isinstance(self.model_tokenizer, MistralTokenizer):
        logger.error(
            "Detected Mistral tokenizer when using a Hermes model")
        # unwrap to the underlying tokenizer object
        self.model_tokenizer = self.model_tokenizer.tokenizer

    # state used while parsing tool calls in streaming mode
    self.prev_tool_call_arr: List[Dict] = []
    self.current_tool_id: int = -1
    self.current_tool_name_sent: bool = False
    self.current_tool_initial_sent: bool = False
    self.streamed_args_for_tool: List[str] = [
    ]  # map what has been streamed for each tool so far to a list

    self.tool_call_start_token: str = "<tool_call>"
    self.tool_call_end_token: str = "</tool_call>"

    # Two alternatives: a closed <tool_call>...</tool_call> pair, or an
    # opening tag that runs to the end of the string.
    self.tool_call_regex = re.compile(
        r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
    self.scratch_pad_regex = re.compile(
        r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction.")

    # Look the tags up with .get() so that a missing tag reaches the
    # RuntimeError below instead of surfacing as a bare KeyError, and
    # compare against None so that a legitimate token id of 0 is accepted.
    vocab = self.model_tokenizer.vocab
    start_token_id = vocab.get(self.tool_call_start_token)
    end_token_id = vocab.get(self.tool_call_end_token)
    if start_token_id is None or end_token_id is None:
        raise RuntimeError(
            "Hermes 2 Pro Tool parser could not locate tool call start/end "
            "tokens in the tokenizer!")
    self.tool_call_start_token_id: int = start_token_id
    self.tool_call_end_token_id: int = end_token_id
def extract_tool_calls(self,
                       model_output: str) -> ExtractedToolCallInformation:
    """
    Extract every ``<tool_call>`` payload from a complete model response.

    Each payload is parsed as JSON and must provide ``name`` and
    ``arguments``. Text before the first opening tag is returned as
    ``content`` (``None`` when empty). If the tag is absent, or anything
    fails while parsing, the whole output is returned as plain content
    with ``tools_called=False``.
    """
    start_tag = self.tool_call_start_token

    # sanity check; avoid unnecessary processing
    if start_tag not in model_output:
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    try:
        # The regex has two capture groups (closed tag pair, or open tag
        # to end-of-string); for each match exactly one of them is
        # non-empty, so pick whichever is populated.
        matches = self.tool_call_regex.findall(model_output)

        # load the JSON first, then build the Function and Tool Call
        parsed_calls = []
        for closed_body, open_body in matches:
            parsed_calls.append(json.loads(closed_body or open_body))

        tool_calls = []
        for parsed in parsed_calls:
            # function call args are JSON but as a string
            function = FunctionCall(name=parsed["name"],
                                    arguments=json.dumps(
                                        parsed["arguments"]))
            tool_calls.append(ToolCall(type="function", function=function))

        # everything before the first opening tag is regular content
        leading_text = model_output.partition(start_tag)[0]
        return ExtractedToolCallInformation(tools_called=True,
                                            tool_calls=tool_calls,
                                            content=leading_text or None)
    except Exception as e:
        logger.error("Error in extracting tool call from response %s", e)
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
    """
    Turn one streamed chunk into a delta message for the client.

    The position in the output is derived by counting tool-call start and
    end token ids in the previous and current token sequences. Depending
    on that position the method returns plain content, the initial tool
    call delta, the function name, or an arguments fragment.

    Returns None when nothing should be sent for this chunk: the partial
    JSON is not parseable yet, there are no arguments yet, or any
    exception was raised (it is logged and swallowed).

    Updates the per-instance streaming state (``current_tool_id``,
    ``current_tool_name_sent``, ``current_tool_initial_sent``,
    ``streamed_args_for_tool`` and ``prev_tool_call_arr``).
    """
    logger.debug("delta_text: %s", delta_text)
    logger.debug("delta_token_ids: %s", delta_token_ids)
    # check to see if we should be streaming a tool call - is there a
    # tool call start token anywhere in the tokens generated so far?
    if self.tool_call_start_token_id not in current_token_ids:
        logger.debug("No tool call tokens found!")
        return DeltaMessage(content=delta_text)

    try:

        # figure out where we are in the parsing by counting tool call
        # start & end tags
        prev_tool_start_count = previous_token_ids.count(
            self.tool_call_start_token_id)
        prev_tool_end_count = previous_token_ids.count(
            self.tool_call_end_token_id)
        cur_tool_start_count = current_token_ids.count(
            self.tool_call_start_token_id)
        cur_tool_end_count = current_token_ids.count(
            self.tool_call_end_token_id)

        # case: if we're generating text, OR rounding out a tool call
        if (cur_tool_start_count == cur_tool_end_count
                and prev_tool_end_count == cur_tool_end_count):
            logger.debug("Generating text content! skipping tool parsing.")
            # a delta that is exactly the end tag is not sent as content;
            # control continues below instead
            if delta_text != self.tool_call_end_token:
                return DeltaMessage(content=delta_text)

        # case: if tool open & close tag counts don't match, we're doing
        # imaginary "else" block here
        # something with tools with this diff.
        # flags for partial JSON parsing. exported constants from
        # "Allow" are handled via BIT MASK
        # (until the name has been sent, Allow.STR is masked out of the
        # flags passed to the partial parser)
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR

        # case -- we're starting a new tool call
        if (cur_tool_start_count > cur_tool_end_count
                and cur_tool_start_count > prev_tool_start_count):
            if len(delta_token_ids) > 1:
                # the delta carries more than one token, so take
                # everything after the last start tag
                tool_call_portion = current_text.split(
                    self.tool_call_start_token)[-1]
            else:
                tool_call_portion = None
                delta = None

            text_portion = None

            # set cursors and state appropriately
            self.current_tool_id += 1
            self.current_tool_name_sent = False
            self.current_tool_initial_sent = False
            self.streamed_args_for_tool.append("")
            logger.debug("Starting on a new tool %s", self.current_tool_id)

        # case -- we're updating an existing tool call
        elif (cur_tool_start_count > cur_tool_end_count
              and cur_tool_start_count == prev_tool_start_count):

            # get the portion of the text that's the tool call
            tool_call_portion = current_text.split(
                self.tool_call_start_token)[-1]
            text_portion = None

        # case -- the current tool call is being closed.
        elif (cur_tool_start_count == cur_tool_end_count
              and cur_tool_end_count > prev_tool_end_count):
            diff = self.prev_tool_call_arr[self.current_tool_id].get(
                "arguments")
            if diff:
                # send whatever part of the serialized arguments has not
                # been streamed yet
                diff = json.dumps(diff).replace(
                    self.streamed_args_for_tool[self.current_tool_id], "")
                logger.debug(
                    "Finishing tool and found diff that had not "
                    "been streamed yet: %s", diff)
                self.streamed_args_for_tool[self.current_tool_id] \
                    += diff
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=diff).model_dump(
                                          exclude_none=True))
                ])
            # NOTE(review): when there is no diff, this branch falls
            # through without assigning tool_call_portion; the parsing
            # step below then presumably fails and ends up in the outer
            # `except` (returning None) - confirm this is intended.

        # case -- otherwise we're just generating text
        else:
            text = delta_text.replace(self.tool_call_start_token, "")
            text = text.replace(self.tool_call_end_token, "")
            delta = DeltaMessage(tool_calls=[], content=text)
            return delta

        try:

            current_tool_call = partial_json_parser.loads(
                tool_call_portion or "{}",
                flags) if tool_call_portion else None
            logger.debug("Parsed tool call %s", current_tool_call)
        except partial_json_parser.core.exceptions.MalformedJSON:
            logger.debug('not enough tokens to parse into JSON yet')
            return None

        # case - we haven't sent the initial delta with the tool call ID
        #   (it will be sent)
        if not self.current_tool_initial_sent:
            self.current_tool_initial_sent = True
            return DeltaMessage(tool_calls=[
                InitialDeltaToolCall(
                    index=self.current_tool_id).model_dump(
                        exclude_none=True)
            ])

        # case - we haven't sent the tool name yet. If it's available, send
        #   it. otherwise, wait until it's available.
        elif not self.current_tool_name_sent:
            # NOTE(review): current_tool_call is None when
            # tool_call_portion was empty/None; .get() would then raise
            # and be handled by the outer `except` - confirm.
            function_name: Union[str, None] = current_tool_call.get("name")
            if function_name:
                self.current_tool_name_sent = True
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      name=function_name).model_dump(
                                          exclude_none=True))
                ])
            else:
                return None
        # case -- otherwise, send the tool call delta

        # if the tool call portion is None, send the delta as text
        if tool_call_portion is None:
            # if there's text but not tool calls, send that -
            # otherwise None to skip chunk
            delta = DeltaMessage(content=delta_text) \
                if text_portion is not None else None
            return delta

        # now, the nitty-gritty of tool calls
        # now we have the portion to parse as tool call.

        logger.debug("Trying to parse current tool call with ID %s",
                     self.current_tool_id)

        # if we're starting a new tool call, push an empty object in as
        #   a placeholder for the arguments
        if len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})

        # main logic for tool parsing here - compare prev. partially-parsed
        #   JSON to the current partially-parsed JSON
        prev_arguments = (
            self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
        cur_arguments = current_tool_call.get("arguments")

        logger.debug("diffing old arguments: %s", prev_arguments)
        logger.debug("against new ones: %s", cur_arguments)

        # case -- no arguments have been created yet. skip sending a delta.
        if not cur_arguments and not prev_arguments:
            logger.debug("Skipping text %s - no arguments", delta_text)
            delta = None

        # case -- prev arguments are defined, but none are now.
        #   probably impossible, but not a fatal error - just keep going
        elif not cur_arguments and prev_arguments:
            logger.error("should be impossible to have arguments reset "
                         "mid-call. skipping streaming anything.")
            delta = None

        # case -- we now have the first info about arguments available from
        #   autocompleting the JSON
        elif cur_arguments and not prev_arguments:

            cur_arguments_json = json.dumps(cur_arguments)
            logger.debug("finding %s in %s", delta_text,
                         cur_arguments_json)

            # get the location where previous args differ from current
            # (str.index raises ValueError if delta_text is not found in
            # the serialized arguments; the outer `except` handles that)
            args_delta_start_loc = cur_arguments_json.index(delta_text) \
                + len(delta_text)

            # use that to find the actual delta
            arguments_delta = cur_arguments_json[:args_delta_start_loc]
            logger.debug("First tokens in arguments received: %s",
                         arguments_delta)

            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              function=DeltaFunctionCall(
                                  arguments=arguments_delta).model_dump(
                                      exclude_none=True))
            ])
            self.streamed_args_for_tool[self.current_tool_id] \
                += arguments_delta

        # last case -- we have an update to existing arguments.
        elif cur_arguments and prev_arguments:

            cur_args_json = json.dumps(cur_arguments)
            prev_args_json = json.dumps(prev_arguments)
            logger.debug("Searching for diff between\n%s", cur_args_json)
            logger.debug("and\n%s", prev_args_json)
            argument_diff = extract_intermediate_diff(
                cur_args_json, prev_args_json)
            logger.debug("got argument diff %s", argument_diff)
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              function=DeltaFunctionCall(
                                  arguments=argument_diff).model_dump(
                                      exclude_none=True))
            ])
            self.streamed_args_for_tool[self.current_tool_id] \
                += argument_diff

        # handle saving the state for the current tool into
        # the "prev" list for use in diffing for the next iteration
        if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
            self.prev_tool_call_arr[self.current_tool_id] = \
                current_tool_call
        else:
            self.prev_tool_call_arr.append(current_tool_call)

        return delta

    except Exception as e:
        # best-effort: any failure while handling this chunk is logged and
        # the chunk is dropped
        logger.error("Error trying to handle streaming tool call: %s", e)
        return None  # do not stream a delta. skip this token ID.
import json
import re
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
InitialDeltaToolCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
logger = init_logger(__name__)
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice and --tool-call-parser mistral are both set
"""
def __init__(self, tokenizer: AnyTokenizer):
    """
    Prepare streaming state, the ``[TOOL_CALLS]`` marker and its token id,
    and the regex used to locate the tool call array.
    """
    super().__init__(tokenizer)

    if not isinstance(self.model_tokenizer, MistralTokenizer):
        logger.info("Non-Mistral tokenizer detected when using a Mistral "
                    "model...")
    else:
        # unwrap to the underlying tokenizer object
        self.model_tokenizer = self.model_tokenizer.tokenizer

    # marker that precedes the tool call array, and its vocabulary id
    self.bot_token = "[TOOL_CALLS]"
    self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
    self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)

    # initialize properties used for state when parsing tool calls in
    # streaming mode
    self.streamed_args_for_tool: List[str] = [
    ]  # map what has been streamed for each tool so far to a list
    self.prev_tool_call_arr: List[Dict] = []
    self.current_tool_id: int = -1
    self.current_tool_initial_sent: bool = False
    self.current_tool_name_sent: bool = False
def extract_tool_calls(self,
                       model_output: str) -> ExtractedToolCallInformation:
    """
    Extract the tool calls from a complete model response.

    The ``[TOOL_CALLS]`` marker is removed, the first span matched by
    ``tool_call_regex`` is parsed with ``json.loads`` exactly as generated
    (no quote rewriting is performed here), and every element of the
    resulting array must provide ``name`` and ``arguments``. Text before
    the marker is returned as ``content`` (``None`` when empty).

    If the marker is absent, or locating/parsing the array fails for any
    reason, the whole output is returned as plain content with
    ``tools_called=False``.
    """
    # case -- if a tool call token is not present, return a text response
    if self.bot_token not in model_output:
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    try:
        # use a regex to find the tool call after removing the BOT token;
        # indexing [0] raises IndexError when nothing matches, which is
        # handled by the fallback below
        raw_tool_call = self.tool_call_regex.findall(
            model_output.replace(self.bot_token, ""))[0]

        # load the JSON, and then use it to build the Function and
        # Tool Call
        function_call_arr = json.loads(raw_tool_call)
        tool_calls: List[ToolCall] = [
            ToolCall(
                type="function",
                function=FunctionCall(
                    name=raw_function_call["name"],
                    # function call args are JSON but as a string
                    arguments=json.dumps(raw_function_call["arguments"])))
            for raw_function_call in function_call_arr
        ]

        # get any content before the tool call
        content = model_output.split(self.bot_token)[0]
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=content if len(content) > 0 else None)

    except Exception as e:
        # Deliberate best-effort: report through the logger (with
        # traceback) rather than printing to stdout, then treat the
        # output as regular text.
        logger.exception(
            "Error in extracting tool call from response: %s", e)
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
    """
    Turn one streamed chunk into a delta message for the client.

    Until the ``[TOOL_CALLS]`` token id shows up in ``current_token_ids``
    the delta is passed through as content. Afterwards the text following
    the marker is partially parsed as a JSON array on every call, and the
    method emits, in turn, the initial tool call delta, the function name
    and argument fragments for the tool at ``current_tool_id``.

    Returns None when nothing should be sent for this chunk (only the
    marker arrived, the JSON is not parseable yet, nothing new is
    available, or an exception was raised and logged).

    Updates the per-instance streaming state (``current_tool_id``,
    ``current_tool_name_sent``, ``current_tool_initial_sent``,
    ``streamed_args_for_tool`` and ``prev_tool_call_arr``).
    """

    # if the tool call token is not in the tokens generated so far, append
    # output to contents since it's not a tool
    if self.bot_token_id not in current_token_ids:
        return DeltaMessage(content=delta_text)

    # if the tool call token ID IS in the tokens generated so far, that
    # means we're parsing as tool calls now

    # handle if we detected the BOT token which means the start of tool
    # calling
    if (self.bot_token_id in delta_token_ids
            and len(delta_token_ids) == 1):
        # if it's the only token, return None, so we don't send a chat
        # completion and don't send a control token
        return None

    # bit mask flags for partial JSON parsing. If the name hasn't been
    # sent yet, don't allow sending
    # an incomplete string since OpenAI only ever (as far as I have
    # seen) allows sending the entire tool/ function name at once.
    flags = Allow.ALL if self.current_tool_name_sent \
        else Allow.ALL & ~Allow.STR
    try:

        # take the text that follows the BOT token; this is the portion
        # that is parsed as JSON
        # NOTE(review): earlier comments mentioned converting single
        # quotes to double quotes here, but no such conversion is applied
        # to this text - confirm whether it is still needed.
        parsable_arr = current_text.split(self.bot_token)[1]

        # tool calls are generated in an array, so do partial JSON
        # parsing on the entire array
        try:
            tool_call_arr: List[Dict] = partial_json_parser.loads(
                parsable_arr, flags)
        except partial_json_parser.core.exceptions.MalformedJSON:
            logger.debug('not enough tokens to parse into JSON yet')
            return None

        # select as the current tool call the one we're on the state at
        # (with current_tool_id == -1 this indexes the last element)

        current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
            if len(tool_call_arr) > 0 else {}

        # case -- if no tokens have been streamed for the tool, e.g.
        #   only the array brackets, stream nothing
        if len(tool_call_arr) == 0:
            return None

        # case: we are starting a new tool in the array
        #   -> array has > 0 length AND length has moved past cursor
        elif (len(tool_call_arr) > 0
              and len(tool_call_arr) > self.current_tool_id + 1):

            # if we're moving on to a new call, first make sure we
            # haven't missed anything in the previous one that was
            # auto-generated due to JSON completions, but wasn't
            # streamed to the client yet.
            if self.current_tool_id >= 0:
                diff: Union[str, None] = current_tool_call.get("arguments")

                if diff:
                    # drop the part of the serialized arguments that was
                    # already streamed for this tool
                    diff = json.dumps(diff).replace(
                        self.streamed_args_for_tool[self.current_tool_id],
                        "")
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=diff).model_dump(
                                              exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += diff
                else:
                    delta = None
            else:
                delta = None
            # re-set stuff pertaining to progress in the current tool
            self.current_tool_id = len(tool_call_arr) - 1
            self.current_tool_name_sent = False
            self.current_tool_initial_sent = False
            self.streamed_args_for_tool.append("")
            logger.debug("starting on new tool %d", self.current_tool_id)
            return delta

        # case: update an existing tool - this is handled below

        # if the current tool initial data incl. the id, type=function
        # and idx not sent, send that
        if not self.current_tool_initial_sent:
            self.current_tool_initial_sent = True
            delta = DeltaMessage(tool_calls=[
                InitialDeltaToolCall(
                    index=self.current_tool_id).model_dump(
                        exclude_none=True)
            ])

        # if the current tool name hasn't been sent, send if available
        # - otherwise send nothing
        elif not self.current_tool_name_sent:
            function_name = current_tool_call.get("name")
            if function_name:

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      name=function_name).model_dump(
                                          exclude_none=True))
                ])
                self.current_tool_name_sent = True
            else:
                delta = None

        # now we know we're on the same tool call and we're streaming
        # arguments
        else:

            prev_arguments = self.prev_tool_call_arr[
                self.current_tool_id].get("arguments")
            cur_arguments = current_tool_call.get("arguments")

            # the delta with single quotes swapped for double quotes, used
            # to locate the new text inside the serialized arguments
            new_text = delta_text.replace("\'", "\"")

            if not cur_arguments and not prev_arguments:

                delta = None
            elif not cur_arguments and prev_arguments:
                logger.error(
                    "INVARIANT - impossible to have arguments reset "
                    "mid-arguments")
                delta = None
            elif cur_arguments and not prev_arguments:
                cur_arguments_json = json.dumps(cur_arguments)
                logger.debug("finding %s in %s", new_text,
                             cur_arguments_json)

                # everything up to and including the new text is the
                # first arguments fragment (str.index raises ValueError if
                # the text is not found; the outer `except` handles that)
                arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                     index(new_text) +
                                                     len(new_text)]
                logger.debug("First tokens in arguments received: %s",
                             arguments_delta)
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=arguments_delta).
                                  model_dump(exclude_none=True))
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] += arguments_delta

            elif cur_arguments and prev_arguments:
                cur_args_json = json.dumps(cur_arguments)
                prev_args_json = json.dumps(prev_arguments)
                logger.debug("Searching for diff between \n%s\n%s",
                             cur_args_json, prev_args_json)

                argument_diff = extract_intermediate_diff(
                    cur_args_json, prev_args_json)
                logger.debug("got arguments diff: %s", argument_diff)
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=argument_diff).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] += argument_diff
            else:
                # try parsing it with regular JSON - if it works we're
                # at the end, and we need to send the difference between
                # tokens streamed so far and the valid JSON
                # NOTE(review): no such parsing happens in this branch; it
                # only sets delta to None - confirm whether this is a
                # leftover TODO.
                delta = None

        # check to see if the name is defined and has been sent. if so,
        # stream the name - otherwise keep waiting
        # finish by setting old and returning None as base case
        self.prev_tool_call_arr = tool_call_arr
        return delta

    except Exception as e:
        # best-effort: any failure while handling this chunk is logged and
        # the chunk is dropped
        logger.error("Error trying to handle streaming tool call: %s", e)
        logger.debug(
            "Skipping chunk as a result of tool streaming extraction "
            "error")
        return None
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Return the longest prefix shared by two strings ('' if there is none).

    Order of arguments is NOT important.

    This function is provided as a UTILITY for extracting information from
    JSON generated by partial_json_parser, to help in ensuring that the
    right tokens are returned in streaming, so that close-quotes,
    close-brackets and close-braces are not returned prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'
    """
    # Count matching leading characters and slice once, instead of growing
    # a string one character at a time.
    matched = 0
    for left, right in zip(s1, s2):
        if left != right:
            break
        matched += 1
    return s1[:matched]
def find_common_suffix(s1: str, s2: str) -> str:
    """
    Return the longest shared suffix made only of non-alphanumeric
    characters ('' if there is none).

    Order of arguments is NOT important. Scanning runs from the end of both
    strings and stops at the first mismatch OR the first alphanumeric
    character.

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    # Count matching trailing characters and slice once, instead of
    # prepending to a string one character at a time.
    matched = 0
    for left, right in zip(reversed(s1), reversed(s2)):
        if left != right or left.isalnum():
            break
        matched += 1
    # slice from an explicit index: s1[-0:] would return the whole string
    return s1[len(s1) - matched:]
def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Return the middle part of ``curr`` that is new relative to ``old``,
    given that the two strings share a common prefix and/or suffix.

    This function is provided as a UTILITY for extracting information from
    JSON generated by partial_json_parser, to help in ensuring that the
    right tokens are returned in streaming, so that close-quotes,
    close-brackets and close-braces are not returned prematurely.

    The order of arguments IS important: the new version of the
    partially-parsed JSON must be the first argument, and the second
    argument must be from the previous generation. The return value is the
    text that should be streamed to the client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
    -> 'ple'
    """
    shared_tail = find_common_suffix(curr, old)

    # Both strings end with the shared tail, so dropping it is a plain
    # slice off the end.
    old_head = old[:len(old) - len(shared_tail)]
    shared_head = find_common_prefix(curr, old_head)

    middle = curr[:len(curr) - len(shared_tail)]
    # Strip the shared head once, and only when it is actually present
    # (it may be longer than what is left after removing the tail).
    if shared_head and middle.startswith(shared_head):
        middle = middle[len(shared_head):]
    return middle
def find_all_indices(string, substring):
    """
    Return the starting index of every occurrence of ``substring`` in
    ``string``, in ascending order. Overlapping occurrences are included.
    Useful for tool call extraction.
    """
    positions = []
    cursor = string.find(substring)
    while cursor != -1:
        positions.append(cursor)
        # resume one character later so overlapping matches are found too
        cursor = string.find(substring, cursor + 1)
    return positions
...@@ -35,6 +35,7 @@ if TYPE_CHECKING: ...@@ -35,6 +35,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: bool = False VLLM_USE_FLASHINFER_SAMPLER: bool = False
VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_OMP_THREADS_BIND: str = ""
...@@ -220,6 +221,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -220,6 +221,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Internal flag to enable Dynamo graph capture # Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE": "VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
lambda:
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
...@@ -372,7 +377,7 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -372,7 +377,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
os.path.join(get_default_cache_root(), "vllm", "xla_cache"), os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
)), )),
"VLLM_FUSED_MOE_CHUNK_SIZE": "VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")), lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),
# If set, vllm will skip the deprecation warnings. # If set, vllm will skip the deprecation warnings.
"VLLM_NO_DEPRECATION_WARNING": "VLLM_NO_DEPRECATION_WARNING":
...@@ -424,6 +429,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -424,6 +429,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_TORCH_PROFILER_DIR": "VLLM_TORCH_PROFILER_DIR":
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -11,8 +11,9 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ...@@ -11,8 +11,9 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor) ResultHandler, WorkerMonitor)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port, from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async) get_vllm_instance_id, make_async)
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
......
...@@ -6,7 +6,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase ...@@ -6,7 +6,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -64,8 +65,9 @@ class DistributedGPUExecutor(GPUExecutor): ...@@ -64,8 +65,9 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks=num_cpu_blocks) num_cpu_blocks=num_cpu_blocks)
def execute_model( def execute_model(
self, self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None: if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers( self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop", "start_worker_execution_loop",
...@@ -188,7 +190,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): ...@@ -188,7 +190,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
@abstractmethod @abstractmethod
async def _driver_execute_model_async( async def _driver_execute_model_async(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
"""Execute the model asynchronously in the driver worker. """Execute the model asynchronously in the driver worker.
......
...@@ -6,8 +6,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -6,8 +6,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest
class ExecutorBase(ABC): class ExecutorBase(ABC):
......
...@@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union ...@@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
...@@ -176,5 +177,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): ...@@ -176,5 +177,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
execute_model_req: ExecuteModelRequest, execute_model_req: ExecuteModelRequest,
) -> List[Union[SamplerOutput, PoolerOutput]]: ) -> List[Union[SamplerOutput, PoolerOutput]]:
output = await make_async(self.driver_worker.execute_model output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, ) )(execute_model_req=execute_model_req)
return output return output
...@@ -14,7 +14,8 @@ from vllm.executor.gpu_executor import create_worker ...@@ -14,7 +14,8 @@ from vllm.executor.gpu_executor import create_worker
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor) ResultHandler, WorkerMonitor)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils import maybe_set_triton_cache_manager from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
get_distributed_init_method, get_open_port, get_distributed_init_method, get_open_port,
...@@ -30,16 +31,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -30,16 +31,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
uses_ray: bool = False uses_ray: bool = False
def _init_executor(self) -> None: def _init_executor(self) -> None:
self._check_executor_parameters()
# Create the parallel GPU workers. # Create the parallel GPU workers.
world_size = self.parallel_config.world_size world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size tensor_parallel_size = self.parallel_config.tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
...@@ -68,16 +65,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -68,16 +65,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
if world_size > 1: if world_size > 1:
maybe_set_triton_cache_manager() maybe_set_triton_cache_manager()
cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local gpu count ({cuda_device_count})")
assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")
# Multiprocessing-based executor does not support multi-node setting. # Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address # Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication. # 127.0.0.1 for communication.
...@@ -139,6 +126,26 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -139,6 +126,26 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers=self.parallel_config. max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers) max_parallel_loading_workers)
def _check_executor_parameters(self):
    """Validate the parallel configuration against the local CUDA devices.

    Sets ``CUDA_VISIBLE_DEVICES`` to ``0..world_size-1`` when the variable
    is not already present in the environment, then checks that both
    ``tensor_parallel_size`` and ``world_size`` do not exceed the device
    count reported by ``cuda_device_count_stateless()``.

    Raises:
        AssertionError: If ``tensor_parallel_size`` or ``world_size``
            is larger than the reported device count.
    """
    world_size = self.parallel_config.world_size
    tensor_parallel_size = self.parallel_config.tensor_parallel_size

    # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        update_environment_variables({
            "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
        })

    cuda_device_count = cuda_device_count_stateless()
    # Check tensor_parallel_size first so that the more common TP-only case
    # gets a message phrased in terms of tensor_parallel_size rather than
    # world_size.
    assert tensor_parallel_size <= cuda_device_count, (
        f"please set tensor_parallel_size ({tensor_parallel_size}) "
        f"to at most the max local gpu count ({cuda_device_count})")

    assert world_size <= cuda_device_count, (
        f"please ensure that world_size ({world_size}) "
        f"is at most the max local gpu count ({cuda_device_count})")
def shutdown(self): def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor", if (worker_monitor := getattr(self, "worker_monitor",
None)) is not None: None)) is not None:
......
import vllm.envs as envs
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
from vllm.executor.xpu_executor import XPUExecutor
from vllm.logger import init_logger
from vllm.utils import make_async
logger = init_logger(__name__)
class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor):
    """Python multiprocessing-based multi-XPU executor"""

    def _check_executor_parameters(self):
        """Require the ``spawn`` multiprocessing start method.

        Raises:
            RuntimeError: If ``envs.VLLM_WORKER_MULTIPROC_METHOD`` is
                anything other than ``"spawn"``.
        """
        if envs.VLLM_WORKER_MULTIPROC_METHOD == "spawn":
            return
        raise RuntimeError(
            "XPU multiprocess executor only support spawn as mp method")
class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor,
                                      MultiprocessingGPUExecutorAsync):
    """Async variant of the multiprocessing-based multi-XPU executor."""

    def __init__(self, *args, **kwargs):
        """Initialize the executor and wrap the driver's ``execute_model``.

        All arguments are forwarded unchanged to the parent initializers.
        """
        super().__init__(*args, **kwargs)
        # Awaitable wrapper around the driver worker's execute_model.
        execute_model = self.driver_worker.execute_model
        self.driver_exec_model = make_async(execute_model)
...@@ -3,8 +3,10 @@ from typing import List, Set, Tuple ...@@ -3,8 +3,10 @@ from typing import List, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.utils import make_async from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -24,14 +26,17 @@ class NeuronExecutor(ExecutorBase): ...@@ -24,14 +26,17 @@ class NeuronExecutor(ExecutorBase):
def _init_worker(self): def _init_worker(self):
from vllm.worker.neuron_worker import NeuronWorker from vllm.worker.neuron_worker import NeuronWorker
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = NeuronWorker( self.driver_worker = NeuronWorker(
self.model_config, model_config=self.model_config,
self.parallel_config, parallel_config=self.parallel_config,
self.scheduler_config, scheduler_config=self.scheduler_config,
self.device_config, device_config=self.device_config,
self.cache_config, cache_config=self.cache_config,
) local_rank=0,
rank=0,
distributed_init_method=distributed_init_method)
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
......
...@@ -9,7 +9,8 @@ from vllm.config import CacheConfig, ModelConfig ...@@ -9,7 +9,8 @@ from vllm.config import CacheConfig, ModelConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
get_open_port, make_async) get_open_port, make_async)
......
...@@ -12,7 +12,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable ...@@ -12,7 +12,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
from vllm.executor.msgspec_utils import encode_hook from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method, from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, get_vllm_instance_id, get_ip, get_open_port, get_vllm_instance_id,
make_async) make_async)
......
...@@ -10,7 +10,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase ...@@ -10,7 +10,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.executor.tpu_executor import TPUExecutor from vllm.executor.tpu_executor import TPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async) get_vllm_instance_id, make_async)
...@@ -70,6 +71,19 @@ class RayTPUExecutor(TPUExecutor): ...@@ -70,6 +71,19 @@ class RayTPUExecutor(TPUExecutor):
worker_module_name = "vllm.worker.tpu_worker" worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker" worker_class_name = "TPUWorker"
# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# need to override the Ray environment variables manually.
override_env = {}
if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
override_env.update({
"TPU_CHIPS_PER_HOST_BOUNDS":
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
})
if "TPU_HOST_BOUNDS" in os.environ:
override_env.update(
{"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
worker = ray.remote( worker = ray.remote(
num_cpus=0, num_cpus=0,
resources={"TPU": 1}, resources={"TPU": 1},
...@@ -80,6 +94,8 @@ class RayTPUExecutor(TPUExecutor): ...@@ -80,6 +94,8 @@ class RayTPUExecutor(TPUExecutor):
worker_class_name=worker_class_name, worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
if override_env:
worker.override_env_vars.remote(override_env)
worker_ip = ray.get(worker.get_node_ip.remote()) worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None: if worker_ip == driver_ip and self.driver_dummy_worker is None:
...@@ -95,12 +111,40 @@ class RayTPUExecutor(TPUExecutor): ...@@ -95,12 +111,40 @@ class RayTPUExecutor(TPUExecutor):
# Else, added to the list of workers. # Else, added to the list of workers.
self.workers.append(worker) self.workers.append(worker)
logger.debug("workers: %s", self.workers)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if self.driver_dummy_worker is None: if self.driver_dummy_worker is None:
raise ValueError( raise ValueError(
"Ray does not allocate any TPUs on the driver node. Consider " "Ray does not allocate any TPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a " "adjusting the Ray placement group or running the driver on a "
"TPU node.") "TPU node.")
worker_ips = [
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
for worker in self.workers
]
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(worker):
    """
    Build the sort key for a Ray worker from 3 properties, in priority
    order:
    1. Workers on the same node as the driver (vllm engine) are
       placed first.
    2. Then, workers on a node hosting fewer workers are placed first.
    3. Finally, workers on a node whose IP address compares smaller
       are placed first.
    """
    node_ip = ray.get(worker.get_node_ip.remote())
    # False sorts before True, so driver-node workers lead the ordering.
    on_other_node = node_ip != driver_ip
    workers_on_node = ip_counts[node_ip]
    return (on_other_node, workers_on_node, node_ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
# Get the set of TPU IDs used on each node. # Get the set of TPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True) use_dummy_driver=True)
......
import os
import time import time
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
...@@ -84,6 +85,9 @@ try: ...@@ -84,6 +85,9 @@ try:
return output return output
def override_env_vars(self, vars: Dict[str, str]) -> None:
    """Merge ``vars`` into this process's ``os.environ``.

    Variables that already exist are overwritten with the new value;
    variables not named in ``vars`` are left untouched.
    """
    os.environ.update(vars)
ray_import_err = None ray_import_err = None
except ImportError as e: except ImportError as e:
...@@ -291,3 +295,28 @@ def initialize_ray_cluster( ...@@ -291,3 +295,28 @@ def initialize_ray_cluster(
_verify_bundles(current_placement_group, parallel_config, device_str) _verify_bundles(current_placement_group, parallel_config, device_str)
# Set the placement group in the parallel config # Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group parallel_config.placement_group = current_placement_group
def get_num_tpu_nodes() -> int:
    """Return the number of TPU nodes in the Ray cluster.

    The count is the cluster-wide ``"TPU"`` resource total divided by the
    accelerator count of the current node, which must divide it evenly.
    NOTE(review): this presumes every node has as many TPUs as the current
    one -- confirm against the deployment.
    """
    from ray._private.accelerators import TPUAcceleratorManager

    tpu_total = int(ray.cluster_resources()["TPU"])
    per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
    node_count, leftover = divmod(tpu_total, per_node)
    assert leftover == 0
    return node_count
def get_num_nodes_in_placement_group() -> int:
    """Return how many distinct nodes host the current placement group.

    Returns 0 when there is no current placement group. Otherwise the
    distinct node ids in the ``"bundles_to_node_id"`` mapping of the
    matching placement-group table entry are counted.
    """
    pg_table = ray.util.placement_group_table()
    current_pg = ray.util.get_current_placement_group()
    if not current_pg:
        return 0

    current_id = current_pg.id.hex()
    # Several bundles may map to the same node, so collect into a set.
    node_ids = {
        node_id
        for pg_key, pg in pg_table.items() if pg_key == current_id
        for node_id in pg["bundles_to_node_id"].values()
    }
    return len(node_ids)
...@@ -5,7 +5,8 @@ import torch ...@@ -5,7 +5,8 @@ import torch
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment