Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
......@@ -110,6 +110,17 @@ class OpenAIServingEmbedding(OpenAIServing):
request_id = f"embd-{random_uuid()}"
created_time = int(time.monotonic())
truncate_prompt_tokens = None
if request.truncate_prompt_tokens is not None:
if request.truncate_prompt_tokens <= self.max_model_len:
truncate_prompt_tokens = request.truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try:
......@@ -123,11 +134,9 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params = request.to_pooling_params()
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
request.input,
))
self._tokenize_prompt_input_or_inputs(request, tokenizer,
request.input,
truncate_prompt_tokens))
for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}"
......@@ -148,6 +157,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params,
request_id_item,
lora_request=lora_request,
priority=request.priority,
)
generators.append(generator)
......
......@@ -27,11 +27,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter
......@@ -168,15 +166,6 @@ class OpenAIServing:
})
return json_str
async def _guided_decode_logits_processor(
        self, request: Union[ChatCompletionRequest, CompletionRequest],
        tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
    """Build the guided-decoding logits processor for ``request``.

    The backend named on the request is used when it is truthy; otherwise
    the backend from the engine's decoding config is used.
    """
    # The engine's decoding config is always fetched first, even when the
    # request carries its own backend.
    engine_decoding_config = await self.engine_client.get_decoding_config()
    backend = request.guided_decoding_backend
    if not backend:
        backend = engine_decoding_config.guided_decoding_backend
    processor = await get_guided_decoding_logits_processor(
        backend, request, tokenizer)
    return processor
async def _check_model(
self,
request: AnyRequest,
......@@ -382,7 +371,8 @@ class OpenAIServing:
self,
request_id: str,
inputs: Union[str, List[int], TextTokensPrompt],
params: Optional[Union[SamplingParams, PoolingParams]],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
......
......@@ -87,6 +87,7 @@ class OpenAIServingTokenization(OpenAIServing):
messages=request.messages,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
)
else:
prompt = apply_hf_chat_template(
......@@ -94,6 +95,7 @@ class OpenAIServingTokenization(OpenAIServing):
conversation=conversation,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
)
else:
prompt = request.prompt
......
from .abstract_tool_parser import ToolParser
from .abstract_tool_parser import ToolParser, ToolParserManager
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser
__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"]
\ No newline at end of file
__all__ = [
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
]
from typing import Dict, List, Sequence, Union
import importlib
import importlib.util
import os
from functools import cached_property
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
from vllm.entrypoints.openai.protocol import (DeltaMessage,
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import is_list_of
logger = init_logger(__name__)
......@@ -24,8 +30,22 @@ class ToolParser:
self.model_tokenizer = tokenizer
def extract_tool_calls(self,
model_output: str) -> ExtractedToolCallInformation:
@cached_property
def vocab(self) -> Dict[str, int]:
    """Vocabulary of the model tokenizer, looked up once and then cached.

    NOTE: ``get_vocab()`` is called rather than reading ``.vocab`` because
    only PreTrainedTokenizerFast is guaranteed to have ``.vocab``, whereas
    all tokenizers have ``get_vocab()``.
    """
    tokenizer = self.model_tokenizer
    return tokenizer.get_vocab()
def adjust_request(
        self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    """
    Hook for adjusting the request parameters.

    This implementation hands the request back unchanged.
    """
    adjusted_request = request
    return adjusted_request
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Static method that should be implemented for extracting tool calls from
a complete model-generated string.
......@@ -44,6 +64,7 @@ class ToolParser:
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
"""
Instance method that should be implemented for extracting tool calls
......@@ -55,3 +76,86 @@ class ToolParser:
raise NotImplementedError(
"AbstractToolParser.extract_tool_calls_streaming has not been "
"implemented!")
class ToolParserManager:
    """Registry that maps names to ToolParser subclasses."""

    # Class-level registry shared by all users: name -> registered class.
    tool_parsers: Dict[str, Type] = {}

    @classmethod
    def get_tool_parser(cls, name) -> Type:
        """
        Return the tool parser class registered under `name`.

        Raises a KeyError if nothing has been registered under that name.
        """
        if name not in cls.tool_parsers:
            raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
        return cls.tool_parsers[name]

    @classmethod
    def _register_module(cls,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = True) -> None:
        """
        Store `module` in the registry under one or more names.

        `module_name` may be None (the class's own `__name__` is used), a
        single string, or a list of strings. With `force` false, a name that
        is already registered raises a KeyError instead of being overwritten.
        Raises a TypeError if `module` is not a subclass of ToolParser.
        """
        if not issubclass(module, ToolParser):
            raise TypeError(
                f'module must be subclass of ToolParser, but got {type(module)}'
            )

        # Normalise the accepted name forms into a list of names.
        if module_name is None:
            names = [module.__name__]
        elif isinstance(module_name, str):
            names = [module_name]
        else:
            names = module_name

        for name in names:
            if force or name not in cls.tool_parsers:
                cls.tool_parsers[name] = module
                continue
            existed_module = cls.tool_parsers[name]
            raise KeyError(f'{name} is already registered '
                           f'at {existed_module.__module__}')

    @classmethod
    def register_module(
            cls,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = True,
            module: Union[Type, None] = None) -> Union[type, Callable]:
        """
        Register a module under the given name or list of names.

        Works either as a decorator (`module` left as None), in which case
        the decorating function is returned, or as a normal function
        (`module` given), in which case the module itself is returned after
        registration.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # validate the name up front so a bad name fails at decoration time
        name_is_valid = (name is None or isinstance(name, str)
                         or is_list_of(name, str))
        if not name_is_valid:
            raise TypeError(
                'name must be None, an instance of str, or a sequence of str, '
                f'but got {type(name)}')

        def _register(module):
            cls._register_module(module=module, module_name=name, force=force)
            return module

        # decorator form: @x.register_module()
        if module is None:
            return _register

        # plain-call form: x.register_module(module=SomeClass)
        return _register(module)

    @classmethod
    def import_tool_parser(cls, plugin_path: str) -> None:
        """
        Import a user-defined tool parser from the file at `plugin_path`.

        The module name is the file's base name without its extension. If no
        import spec (or no loader) can be built for the path, an error is
        logged and nothing is imported.
        """
        plugin_file = os.path.basename(plugin_path)
        module_name = os.path.splitext(plugin_file)[0]
        spec = importlib.util.spec_from_file_location(module_name, plugin_path)
        if spec is not None and spec.loader is not None:
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return
        logger.error("load %s from %s failed.", module_name, plugin_path)
......@@ -5,12 +5,13 @@ from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser)
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
......@@ -20,6 +21,7 @@ from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module("hermes")
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
......@@ -48,17 +50,19 @@ class Hermes2ProToolParser(ToolParser):
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_call_start_token_id: int = self.model_tokenizer.vocab[
self.tool_call_start_token]
self.tool_call_end_token_id: int = self.model_tokenizer.vocab[
self.tool_call_end_token]
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
def extract_tool_calls(self,
model_output: str) -> ExtractedToolCallInformation:
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_call_start_token not in model_output:
......@@ -99,9 +103,9 @@ class Hermes2ProToolParser(ToolParser):
tool_calls=tool_calls,
content=content if content else None)
except Exception as e:
logger.error("Error in extracting tool call from response %s",
e)
except Exception:
logger.exception(
"Error in extracting tool call from response.")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
......@@ -114,6 +118,7 @@ class Hermes2ProToolParser(ToolParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
logger.debug("delta_text: %s", delta_text)
......@@ -328,6 +333,6 @@ class Hermes2ProToolParser(ToolParser):
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID.
import json
from typing import Dict, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module(["internlm"])
class Internlm2ToolParser(ToolParser):
    """
    Tool call parser registered under the name "internlm".

    A tool call is expected as a single JSON object that follows the
    '<|action_start|><|plugin|>' marker and is terminated by
    '<|action_end|>'. Parallel tool calls are not supported.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # Offset into the streamed text up to which plain content has
        # already been handed back to the caller.
        # NOTE(review): the other streaming state used below
        # (current_tool_id, current_tool_name_sent, prev_tool_call_arr,
        # streamed_args_for_tool) is not set here; presumably it is
        # initialised by ToolParser.__init__ - confirm.
        self.position = 0

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """
        Disable special-token skipping when the request may call tools.

        The request is modified in place and returned.
        """
        if request.tools and request.tool_choice != 'none':
            # do not skip special tokens because internlm uses the special
            # tokens to indicate the start and end of the tool call
            # information.
            request.skip_special_tokens = False
        return request

    def get_argments(self, obj):
        """
        Return the tool arguments stored in `obj`.

        The "parameters" key takes precedence over "arguments"; None is
        returned when neither key is present.
        """
        if "parameters" in obj:
            return obj.get("parameters")
        elif "arguments" in obj:
            return obj.get("arguments")
        return None

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Turn the newest streamed chunk into a DeltaMessage.

        Returns a content delta while no tool call has started, a tool-call
        delta (name first, then argument increments) once the action marker
        has been seen, or None when there is nothing to stream for this
        chunk (including when parsing fails).
        """
        if '<|action_start|>' not in current_text:
            self.position = len(current_text)
            return DeltaMessage(content=delta_text)
        # if the tool call has been sent, return an empty delta message
        # to make sure the finish_reason will be sent correctly.
        if self.current_tool_id > 0:
            return DeltaMessage(content='')

        last_pos = self.position
        if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
            return None

        new_delta = current_text[last_pos:]
        # maxsplit=1 keeps the two-way unpack valid even if the marker is
        # generated more than once; only the first action is considered.
        text, action = new_delta.split('<|action_start|><|plugin|>', 1)

        # flush any plain content that precedes the marker first
        if len(text) > 0:
            self.position = self.position + len(text)
            return DeltaMessage(content=text)

        action = action.strip()
        action = action.split('<|action_end|>'.strip())[0]

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR

        try:
            parsable_arr = action

            # tool calls are generated as a single object in internlm2;
            # parallel tool calls are not supported
            try:
                tool_call_arr: Dict = partial_json_parser.loads(
                    parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = tool_call_arr.get("name")
                if function_name:
                    self.current_tool_id = self.current_tool_id + 1
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                    self.streamed_args_for_tool.append("")
                else:
                    delta = None
            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                prev_arguments = self.get_argments(
                    self.prev_tool_call_arr[self.current_tool_id])
                cur_arguments = self.get_argments(tool_call_arr)

                # no arguments generated yet
                if not cur_arguments and not prev_arguments:
                    delta = None
                # will never happen
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset "
                        "mid-arguments")
                    delta = None
                # first time we see parameters
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments)

                    # send everything up to and including the newest chunk
                    arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                         index(delta_text) +
                                                         len(delta_text)]
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=arguments_delta).
                                      model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += arguments_delta
                # both prev and cur parameters, send the newly added part
                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments)
                    prev_args_json = json.dumps(prev_arguments)

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json)

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=argument_diff).model_dump(
                                              exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += argument_diff

            # remember this parse (with arguments normalised under the
            # "arguments" key) so the next chunk can be diffed against it
            tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
            self.prev_tool_call_arr = [tool_call_arr]
            return delta
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool call from a complete model response.

        Returns tools_called=True with a single ToolCall when the output
        contains an action whose name matches one of `request.tools`.
        Otherwise the result has tools_called=False: with the text before
        the marker as content when the tool name is unknown, or with the
        whole model output when there is no marker or the action payload
        cannot be parsed.
        """
        text = model_output
        tools = request.tools
        if '<|action_start|><|plugin|>' in text:
            # maxsplit=1 keeps the two-way unpack valid even if the marker
            # is generated more than once.
            text, action = text.split('<|action_start|><|plugin|>', 1)
            action = action.split('<|action_end|>'.strip())[0]
            action = action[action.find('{'):]
            try:
                action_dict = json.loads(action)
                name, parameters = action_dict['name'], json.dumps(
                    action_dict.get('parameters',
                                    action_dict.get('arguments', {})))
            except (ValueError, KeyError, TypeError, AttributeError):
                # malformed payload: treat the whole output as plain text
                logger.exception(
                    "Error in extracting tool call from response.")
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

            # a call to a tool that was not offered is not a tool call
            if not tools or name not in [t.function.name for t in tools]:
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=text)

            tool_calls = [
                ToolCall(
                    function=FunctionCall(name=name, arguments=parameters))
            ]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=text if len(text) > 0 else None)

        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=text)
import json
import re
from json import JSONDecodeError, JSONDecoder
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__)
# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str, flags):
    """
    Parse possibly-incomplete JSON that may be followed by extra data.

    Returns a ``(value, length)`` tuple. When the partial parser accepts
    the input, ``length`` is ``len(input_str)``. When it fails with an
    "Extra data" error, the result of ``JSONDecoder.raw_decode`` on the
    same input is returned instead. Any other JSONDecodeError is re-raised.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if "Extra data" not in err.msg:
            raise
        return JSONDecoder().raw_decode(input_str)
    return (parsed, len(input_str))
def is_complete_json(input_str):
    """Return True if `input_str` parses as a full JSON document."""
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    else:
        return True
@ToolParserManager.register_module("llama3_json")
class Llama3JsonToolParser(ToolParser):
    """
    Tool call parser for Llama 3.1 models intended for use with the
    examples/tool_chat_template_llama.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser llama3_json
    are all set
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Set up the streaming state and the tool-call marker token."""
        super().__init__(tokenizer)

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: List[Dict] = []
        # -1 means no tool call has been started yet
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = [
        ]  # map what has been streamed for each tool so far to a list
        self.bot_token = "<|python_tag|>"
        # first token id produced when encoding the marker on its own
        self.bot_token_id = tokenizer.encode(self.bot_token,
                                             add_special_tokens=False)[0]
        # NOTE(review): compiled here but not referenced by any method of
        # this class
        self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        The output is treated as tool calls only if it starts with the
        <|python_tag|> token or with '{'. Each JSON object is decoded in
        turn; if anything fails, the whole output is returned as plain
        content with tools_called=False.
        """
        # case -- if a tool call token is not present, return a text response
        if not (model_output.startswith(self.bot_token)
                or model_output.startswith('{')):
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            # load the JSON, and then use it to build the Function and
            # Tool Call
            dec = JSONDecoder()
            function_call_arr = []

            # depending on the prompt format the Llama model may or may not
            # prefix the output with the <|python_tag|> token
            start_idx = len(self.bot_token) if model_output.startswith(
                self.bot_token) else 0
            while start_idx < len(model_output):
                (obj, end_idx) = dec.raw_decode(model_output[start_idx:])
                # skip past the decoded object plus a '; ' separator
                start_idx += end_idx + len('; ')
                function_call_arr.append(obj)

            tool_calls: List[ToolCall] = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(raw_function_call["arguments"] \
                                if "arguments" in raw_function_call \
                                else raw_function_call["parameters"])))
                for raw_function_call in function_call_arr
            ]

            # no text content is returned alongside the tool calls
            ret = ExtractedToolCallInformation(tools_called=True,
                                               tool_calls=tool_calls,
                                               content=None)
            return ret

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # return information to just treat the tool call as regular JSON
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Turn the newest streamed chunk into a DeltaMessage.

        The full `current_text` is re-parsed on every call and compared
        with the state kept on the instance. Returns a content delta when
        the text does not look like a tool call, a tool-call delta (name
        first, then argument increments), or None when there is nothing to
        stream for this chunk (including when parsing fails).
        """

        if not (current_text.startswith(self.bot_token)
                or current_text.startswith('{')):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:
            # parsed tool-call objects, and for each whether its JSON text
            # is already complete
            tool_call_arr = []
            is_complete = []
            try:
                # depending on the prompt format the Llama model may or may not
                # prefix the output with the <|python_tag|> token
                start_idx = len(self.bot_token) if current_text.startswith(
                    self.bot_token) else 0
                while start_idx < len(current_text):
                    (obj,
                     end_idx) = partial_json_loads(current_text[start_idx:],
                                                   flags)
                    is_complete.append(
                        is_complete_json(current_text[start_idx:start_idx +
                                                      end_idx]))
                    # skip past the parsed object plus a '; ' separator
                    start_idx += end_idx + len('; ')
                    # depending on the prompt Llama can use
                    # either arguments or parameters
                    if "parameters" in obj:
                        assert "arguments" not in obj, \
                            "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # select as the current tool call the one we're on the state at
            # (with current_tool_id == -1 this indexes the last element)
            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
                if len(tool_call_arr) > 0 else {}

            # case -- if no tokens have been streamed for the tool, e.g.
            #   only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            #   -> array has > 0 length AND length has moved past cursor
            elif (len(tool_call_arr) > 0
                  and len(tool_call_arr) > self.current_tool_id + 1):

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        # number of argument characters already streamed
                        sent = len(
                            self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                # NOTE: returns before prev_tool_call_arr is updated below
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    # number of argument characters already streamed
                    sent = len(
                        self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[
                        self.current_tool_id].get("arguments")

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # object finished: flush everything not yet sent
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:

                            # only the part shared by the previous and the
                            # current serialisation is streamed
                            prefix = find_common_prefix(
                                prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff

            # remember this parse so the next chunk can be compared to it
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
import json
import re
from random import choices
from string import ascii_letters, digits
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import Field
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser)
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
......@@ -19,7 +23,21 @@ from vllm.utils import random_uuid
logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
class MistralToolCall(ToolCall):
    """ToolCall whose ``id`` defaults to a freshly generated random id."""

    # The lambda defers the lookup of MistralToolCall until a default value
    # is actually requested, i.e. after the class has been created.
    id: str = Field(
        default_factory=lambda: MistralToolCall.generate_random_id())

    @staticmethod
    def generate_random_id():
        """Return a random id made of 9 alphanumeric characters."""
        # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
        # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
        random_chars = choices(ALPHANUMERIC, k=9)
        return "".join(random_chars)
@ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
......@@ -31,9 +49,7 @@ class MistralToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
self.model_tokenizer = self.model_tokenizer.tokenizer
else:
if not isinstance(self.model_tokenizer, MistralTokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral "
"model...")
......@@ -45,11 +61,18 @@ class MistralToolParser(ToolParser):
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
if not self.bot_token_id:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!")
def extract_tool_calls(self,
model_output: str) -> ExtractedToolCallInformation:
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing,
......@@ -71,8 +94,8 @@ class MistralToolParser(ToolParser):
# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr = json.loads(raw_tool_call)
tool_calls: List[ToolCall] = [
ToolCall(
tool_calls: List[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
......@@ -88,8 +111,8 @@ class MistralToolParser(ToolParser):
tool_calls=tool_calls,
content=content if len(content) > 0 else None)
except Exception as e:
logger.error("Error in extracting tool call from response: %s", e)
except Exception:
logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
......@@ -103,6 +126,7 @@ class MistralToolParser(ToolParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# if the tool call token is not in the tokens generated so far, append
......@@ -274,8 +298,8 @@ class MistralToolParser(ToolParser):
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
......
......@@ -38,6 +38,7 @@ if TYPE_CHECKING:
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_OPENVINO_DEVICE: str = "CPU"
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
......@@ -65,7 +66,9 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_DISABLED_KERNELS: List[str] = []
def get_default_cache_root():
......@@ -214,27 +217,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")),
# If set, allowing the use of deprecated beam search implementation
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH":
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1",
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
lambda:
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")),
# Internal flag to control whether we use custom op,
# or use the native pytorch implementation
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool(
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
"VLLM_TORCH_COMPILE_LEVEL":
lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
......@@ -319,6 +307,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_CPU_OMP_THREADS_BIND":
lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
# OpenVINO device selection
# default is CPU
"VLLM_OPENVINO_DEVICE":
lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(),
# OpenVINO key-value cache space
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE":
......@@ -413,6 +406,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda:
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
("1", "true")),
"VLLM_TEST_FORCE_LOAD_FORMAT":
lambda: os.getenv("VLLM_TEST_FORCE_LOAD_FORMAT", "dummy"),
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
......@@ -441,6 +436,21 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda:
(os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
("1", "true")),
# By default, vLLM will check the peer-to-peer capability itself,
# in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa
# If this env var is set to 1, vLLM will skip the peer-to-peer check,
# and trust the driver's peer-to-peer capability report.
"VLLM_SKIP_P2P_CHECK":
lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1",
# List of quantization kernels that should be disabled, used for testing
# and performance comparisons. Currently only affects MPLinearKernel
# selection
# (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
"VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),
}
# end-env-vars-definition
......
......@@ -28,6 +28,8 @@ class CPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# if the feature combo becomes valid
assert self.lora_config is None, "cpu backend doesn't support LoRA"
#
......@@ -324,6 +326,8 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype == torch.float16:
logger.warning("float16 is not supported on CPU, casting to bfloat16.")
config.dtype = torch.bfloat16
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# if the feature combo becomes valid
if not config.enforce_eager:
logger.warning(
"CUDA graph is not supported on CPU, fallback to the eager "
......@@ -334,6 +338,8 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
def _verify_and_get_scheduler_config(
config: SchedulerConfig) -> SchedulerConfig:
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# if the feature combo becomes valid
if config.chunked_prefill_enabled:
logger.warning("Chunked prefill is not supported on CPU, disable it.")
config.chunked_prefill_enabled = False
......@@ -342,6 +348,8 @@ def _verify_and_get_scheduler_config(
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# if the feature combo becomes valid
if config.enable_prefix_caching:
logger.warning("Prefix caching is not supported on CPU, disable it.")
config.enable_prefix_caching = False
......
......@@ -56,6 +56,10 @@ class DistributedGPUExecutor(GPUExecutor):
# have GPUs.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
self.model_config.max_model_len)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
self.model_config.max_model_len, max_concurrency)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
......
......@@ -121,6 +121,10 @@ class GPUExecutor(ExecutorBase):
# remains to abstract away the device for non-GPU configurations.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
self.model_config.max_model_len)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
self.model_config.max_model_len, max_concurrency)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
......
......@@ -15,8 +15,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async,
cuda_is_initialized, get_distributed_init_method,
get_open_port, get_vllm_instance_id, make_async,
update_environment_variables)
logger = init_logger(__name__)
......@@ -122,6 +122,13 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
......
......@@ -3,7 +3,6 @@ import multiprocessing
import os
import sys
import threading
import traceback
import uuid
from dataclasses import dataclass
from multiprocessing import Queue
......@@ -27,9 +26,6 @@ RESET = '\033[0;0m'
JOIN_TIMEOUT_S = 2
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
mp = multiprocessing.get_context(mp_method)
@dataclass
class Result(Generic[T]):
......@@ -77,7 +73,7 @@ class ResultHandler(threading.Thread):
def __init__(self) -> None:
super().__init__(daemon=True)
self.result_queue = mp.Queue()
self.result_queue = get_mp_context().Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
def run(self):
......@@ -147,10 +143,11 @@ class ProcessWorkerWrapper:
def __init__(self, result_handler: ResultHandler,
worker_factory: Callable[[], Any]) -> None:
self._task_queue = mp.Queue()
self.mp = get_mp_context()
self._task_queue = self.mp.Queue()
self.result_queue = result_handler.result_queue
self.tasks = result_handler.tasks
self.process: BaseProcess = mp.Process( # type: ignore[attr-defined]
self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined]
target=_run_worker_process,
name="VllmWorkerProcess",
kwargs=dict(
......@@ -204,7 +201,7 @@ def _run_worker_process(
"""Worker process event loop"""
# Add process-specific prefix to stdout and stderr
process_name = mp.current_process().name
process_name = get_mp_context().current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
......@@ -229,10 +226,9 @@ def _run_worker_process(
except KeyboardInterrupt:
break
except BaseException as e:
tb = traceback.format_exc()
logger.error(
"Exception in worker %s while processing method %s: %s, %s",
process_name, method, e, tb)
logger.exception(
"Exception in worker %s while processing method %s.",
process_name, method)
exception = e
result_queue.put(
Result(task_id=task_id, value=output, exception=exception))
......@@ -268,4 +264,9 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file.start_new_line = False # type: ignore[attr-defined]
file.start_new_line = True # type: ignore[attr-defined]
file.write = write_with_prefix # type: ignore[method-assign]
\ No newline at end of file
file.write = write_with_prefix # type: ignore[method-assign]
def get_mp_context():
    """Return the multiprocessing context for the configured start method.

    The start method is read from ``envs.VLLM_WORKER_MULTIPROC_METHOD`` each
    time this function is called, and the matching context is obtained from
    :func:`multiprocessing.get_context`.
    """
    return multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)
......@@ -17,6 +17,14 @@ from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
logger = init_logger(__name__)
def is_openvino_cpu() -> bool:
    """Tell whether the configured OpenVINO device names a CPU.

    This is a substring test against ``envs.VLLM_OPENVINO_DEVICE``, so any
    device string containing ``"CPU"`` counts.
    """
    device = envs.VLLM_OPENVINO_DEVICE
    return "CPU" in device
def is_openvino_gpu() -> bool:
    """Tell whether the configured OpenVINO device names a GPU.

    This is a substring test against ``envs.VLLM_OPENVINO_DEVICE``, so any
    device string containing ``"GPU"`` counts.
    """
    device = envs.VLLM_OPENVINO_DEVICE
    return "GPU" in device
class OpenVINOExecutor(ExecutorBase):
uses_ray: bool = False
......@@ -24,8 +32,13 @@ class OpenVINOExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert self.device_config.device_type == "openvino"
assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
assert is_openvino_cpu() or is_openvino_gpu(), \
"OpenVINO backend supports only CPU and GPU devices"
self.ov_core = ov.Core()
self.model_config = _verify_and_get_model_config(self.model_config)
self.cache_config = _verify_and_get_cache_config(self.cache_config)
self.cache_config = _verify_and_get_cache_config(
self.ov_core, self.cache_config)
# Instantiate the worker and load the model to CPU.
self._init_worker()
......@@ -40,6 +53,7 @@ class OpenVINOExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = OpenVINOWorker(
ov_core=self.ov_core,
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
......@@ -68,10 +82,13 @@ class OpenVINOExecutor(ExecutorBase):
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
# NOTE: `cpu block` for OpenVINO backend is located on CPU memory but is
# referred as `gpu block`. Because we want to reuse the existing block
# management procedure.
logger.info("# CPU blocks: %d", num_gpu_blocks)
# NOTE: In case of a CPU device, `cpu block` for OpenVINO backend
# is located on CPU memory but is referred as `gpu block`.
# Because we want to reuse the existing block management procedure.
device_blocks = num_gpu_blocks
swap_blocks = num_cpu_blocks
logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d",
envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
......@@ -143,29 +160,45 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
return config
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
def _verify_and_get_cache_config(ov_core: ov.Core,
config: CacheConfig) -> CacheConfig:
if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
logger.info("KV cache type is overried to u8 via "
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
config.cache_dtype = ov.Type.u8
if not is_openvino_cpu():
logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
"ignored for GPU, f16 data type will be used.")
config.cache_dtype = ov.Type.f16
else:
logger.info("KV cache type is overridden to u8 via "
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
config.cache_dtype = ov.Type.u8
else:
core = ov.Core()
inference_precision = core.get_property("CPU",
hints.inference_precision)
if inference_precision == ov.Type.bf16:
config.cache_dtype = ov.Type.bf16
if is_openvino_cpu():
ov_device = envs.VLLM_OPENVINO_DEVICE
inference_precision = ov_core.get_property(
ov_device, hints.inference_precision)
if inference_precision == ov.Type.bf16:
config.cache_dtype = ov.Type.bf16
else:
config.cache_dtype = ov.Type.f16
else:
config.cache_dtype = ov.Type.f16
if config.block_size != 32:
logger.info(
f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
)
config.block_size = 32
if is_openvino_cpu():
if config.block_size != 32:
logger.info(
f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
)
config.block_size = 32
else:
if config.block_size != 16:
logger.info(
f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}" # noqa: G004, E501
)
config.block_size = 16
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
if kv_cache_space >= 0:
if kv_cache_space == 0:
if kv_cache_space == 0 and is_openvino_cpu():
config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
......
from contextlib import contextmanager
from typing import Any
_forward_context: Any = None
def get_forward_context() -> Any:
    """Return whatever forward context is currently stored in this module.

    The value is the module-level ``_forward_context``; it is returned as-is,
    without copying.
    """
    return _forward_context
@contextmanager
def set_forward_context(context: Any):
    """Install ``context`` as the module's forward context for a ``with`` body.

    The context can be attention metadata or any other object. The value that
    was stored before entering is put back on exit, including when the body
    raises, so nested uses unwind in order.
    """
    global _forward_context
    # Swap in the new value while remembering the one it replaces.
    saved, _forward_context = _forward_context, context
    try:
        yield
    finally:
        _forward_context = saved
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry()
......@@ -16,11 +17,14 @@ See also:
__all__ = [
"TextPrompt",
"TokensPrompt",
"PromptInputs",
"SingletonPromptInputs",
"PromptType",
"SingletonPrompt",
"ExplicitEncoderDecoderPrompt",
"LLMInputs",
"EncoderDecoderLLMInputs",
"TokenInputs",
"token_inputs",
"SingletonInputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
......@@ -28,3 +32,34 @@ __all__ = [
"InputContext",
"InputRegistry",
]
def __getattr__(name: str):
    """Resolve deprecated module attributes to their renamed replacements.

    Python only calls this module-level hook when normal attribute lookup
    fails, so names that still exist never reach it.

    Args:
        name: The attribute requested from this module.

    Returns:
        The object currently bound to the replacement name.

    Raises:
        AttributeError: If ``name`` is not one of the deprecated aliases.
    """
    import warnings

    # Old name -> new name. The alias that was actually renamed is the plural
    # ``PromptInputs``; the singular ``PromptInput`` is still accepted so that
    # code relying on that spelling keeps working.
    renamed = {
        "PromptInput": "PromptType",
        "PromptInputs": "PromptType",
        "SingletonPromptInputs": "SingletonPrompt",
        "LLMInputs": "DecoderOnlyInputs",
        "EncoderDecoderLLMInputs": "EncoderDecoderInputs",
    }

    replacement = renamed.get(name)
    if replacement is None:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}")

    msg = (f"{name} has been renamed to {replacement}. "
           "The original name will be removed in an upcoming version.")
    # stacklevel=2 attributes the warning to the code that accessed the name.
    warnings.warn(DeprecationWarning(msg), stacklevel=2)
    return globals()[replacement]
from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
Union)
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Optional, Tuple, Union, cast)
from typing_extensions import NotRequired, TypedDict, TypeVar
......@@ -19,6 +19,14 @@ class TextPrompt(TypedDict):
if the model supports it.
"""
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
class TokensPrompt(TypedDict):
"""Schema for a tokenized prompt."""
......@@ -32,10 +40,18 @@ class TokensPrompt(TypedDict):
if the model supports it.
"""
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single LLM input:
Set of possible schemas for a single prompt:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
......@@ -46,7 +62,7 @@ which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPromptInputs` may be employed
A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
......@@ -55,41 +71,44 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
_T1_co = TypeVar("_T1_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
bound=SingletonPrompt,
default=SingletonPrompt,
covariant=True)
_T2_co = TypeVar("_T2_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
bound=SingletonPrompt,
default=SingletonPrompt,
covariant=True)
# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
"""Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
"""
Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively,
may formatted according to any of the
:class:`SingletonPromptInputs` schemas, and are not
required to have the same schema.
The encoder and decoder prompts, respectively, may be formatted
according to any of the :class:`SingletonPrompt` schemas,
and are not required to have the same schema.
Only the encoder prompt may have multi-modal data.
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder.
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
and that the :code:`encoder_prompt` and :code:`decoder_prompt`
fields of this data structure themselves must be
:class:`SingletonPromptInputs` instances.
:class:`SingletonPrompt` instances.
"""
encoder_prompt: _T1_co
decoder_prompt: Optional[_T2_co]
mm_processor_kwargs: NotRequired[Dict[str, Any]]
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
......@@ -101,13 +120,8 @@ both decoder-only and encoder/decoder input types:
"""
class LLMInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
class TokenInputs(TypedDict):
"""Represents token-based inputs."""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
......@@ -122,8 +136,49 @@ class LLMInputs(TypedDict):
if the model supports it.
"""
mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
def token_inputs(
    prompt_token_ids: List[int],
    prompt: Optional[str] = None,
    multi_modal_data: Optional["MultiModalDataDict"] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
    """Build a :class:`TokenInputs`, leaving out every optional that is None.

    ``prompt_token_ids`` is always stored. Each of ``prompt``,
    ``multi_modal_data`` and ``mm_processor_kwargs`` becomes a key only when
    a non-``None`` value was passed, so absent optionals stay absent rather
    than being stored as ``None``.
    """
    built = TokenInputs(prompt_token_ids=prompt_token_ids)

    # Literal keys are used on purpose so the TypedDict stays type-checkable.
    if prompt is not None:
        built["prompt"] = prompt

    if multi_modal_data is not None:
        built["multi_modal_data"] = multi_modal_data

    if mm_processor_kwargs is not None:
        built["mm_processor_kwargs"] = mm_processor_kwargs

    return built
class EncoderDecoderLLMInputs(LLMInputs):
SingletonInputs = TokenInputs
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""
DecoderOnlyInputs = TokenInputs
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
class EncoderDecoderInputs(TokenInputs):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
......@@ -146,33 +201,51 @@ class EncoderDecoderLLMInputs(LLMInputs):
"""
_T1 = TypeVar("_T1",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
_T2 = TypeVar("_T2",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
def build_explicit_enc_dec_prompt(
encoder_prompt: _T1,
decoder_prompt: Optional[_T2],
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return ExplicitEncoderDecoderPrompt(
encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt,
mm_processor_kwargs=mm_processor_kwargs)
def zip_enc_dec_prompts(
enc_prompts: Iterable[_T1],
dec_prompts: Iterable[Optional[_T2]],
mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]],
Dict[str, Any]]] = None,
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
"""
Zip encoder and decoder prompts together into a list of
:class:`ExplicitEncoderDecoderPrompt` instances.
:class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs
may also be provided; if a dict is passed, the same dictionary will be
used for every encoder/decoder prompt. If an iterable is provided, it will
be zipped with the encoder/decoder prompts.
"""
if mm_processor_kwargs is None:
mm_processor_kwargs = cast(Dict[str, Any], {})
if isinstance(mm_processor_kwargs, dict):
return [
build_explicit_enc_dec_prompt(
encoder_prompt, decoder_prompt,
cast(Dict[str, Any], mm_processor_kwargs))
for (encoder_prompt,
decoder_prompt) in zip(enc_prompts, dec_prompts)
]
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
mm_proc_kwargs)
for (encoder_prompt, decoder_prompt, mm_proc_kwargs
) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
]
......@@ -182,3 +255,34 @@ def to_enc_dec_tuple_list(
return [(enc_dec_prompt["encoder_prompt"],
enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts]
def __getattr__(name: str):
    """Resolve deprecated module attributes to their renamed replacements.

    Python only calls this module-level hook when normal attribute lookup
    fails, so names that still exist never reach it.

    Args:
        name: The attribute requested from this module.

    Returns:
        The object currently bound to the replacement name.

    Raises:
        AttributeError: If ``name`` is not one of the deprecated aliases.
    """
    import warnings

    # Old name -> new name. The alias that was actually renamed is the plural
    # ``PromptInputs``; the singular ``PromptInput`` is still accepted so that
    # code relying on that spelling keeps working.
    renamed = {
        "PromptInput": "PromptType",
        "PromptInputs": "PromptType",
        "SingletonPromptInputs": "SingletonPrompt",
        "LLMInputs": "DecoderOnlyInputs",
        "EncoderDecoderLLMInputs": "EncoderDecoderInputs",
    }

    replacement = renamed.get(name)
    if replacement is None:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}")

    msg = (f"{name} has been renamed to {replacement}. "
           "The original name will be removed in an upcoming version.")
    # stacklevel=2 attributes the warning to the code that accessed the name.
    warnings.warn(DeprecationWarning(msg), stacklevel=2)
    return globals()[replacement]
from typing import List, Literal, Sequence, TypedDict, Union, overload
from typing import List, Literal, Sequence, TypedDict, Union, cast, overload
from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt)
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
TextPrompt, TokensPrompt)
class ParsedText(TypedDict):
......@@ -44,13 +44,16 @@ def parse_and_batch_prompt(
if is_list_of(prompt, str):
# case 2: array of strings
prompt = cast(List[str], prompt)
return [
ParsedText(content=elem, is_tokens=False) for elem in prompt
]
if is_list_of(prompt, int):
# case 3: array of tokens
prompt = cast(List[int], prompt)
return [ParsedTokens(content=prompt, is_tokens=True)]
if is_list_of(prompt, list):
prompt = cast(List[List[int]], prompt)
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
......@@ -81,26 +84,26 @@ class ParsedTokensPrompt(TypedDict):
def parse_singleton_prompt(
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
if isinstance(inputs, str):
return ParsedStrPrompt(type="str", content=inputs)
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict):
if "prompt_token_ids" in prompt:
return ParsedTokensPrompt(type="tokens",
content=inputs) # type: ignore
elif "prompt" in inputs:
return ParsedTextPrompt(type="text", content=inputs)
content=prompt) # type: ignore
elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=prompt)
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
def is_explicit_encoder_decoder_prompt(
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(inputs, dict) and "encoder_prompt" in inputs
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_valid_encoder_decoder_llm_inputs(
inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
) -> TypeIs[EncoderDecoderLLMInputs]:
def is_encoder_decoder_inputs(
inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
) -> TypeIs[EncoderDecoderInputs]:
return "encoder_prompt_token_ids" in inputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment