Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
...@@ -110,6 +110,17 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -110,6 +110,17 @@ class OpenAIServingEmbedding(OpenAIServing):
request_id = f"embd-{random_uuid()}" request_id = f"embd-{random_uuid()}"
created_time = int(time.monotonic()) created_time = int(time.monotonic())
truncate_prompt_tokens = None
if request.truncate_prompt_tokens is not None:
if request.truncate_prompt_tokens <= self.max_model_len:
truncate_prompt_tokens = request.truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try: try:
...@@ -123,11 +134,9 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -123,11 +134,9 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params()
prompts = list( prompts = list(
self._tokenize_prompt_input_or_inputs( self._tokenize_prompt_input_or_inputs(request, tokenizer,
request, request.input,
tokenizer, truncate_prompt_tokens))
request.input,
))
for i, prompt_inputs in enumerate(prompts): for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
...@@ -148,6 +157,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -148,6 +157,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
priority=request.priority,
) )
generators.append(generator) generators.append(generator)
......
...@@ -27,11 +27,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -27,11 +27,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.inputs.parse import parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter from vllm.utils import AtomicCounter
...@@ -168,15 +166,6 @@ class OpenAIServing: ...@@ -168,15 +166,6 @@ class OpenAIServing:
}) })
return json_str return json_str
async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor(
guided_decoding_backend, request, tokenizer)
async def _check_model( async def _check_model(
self, self,
request: AnyRequest, request: AnyRequest,
...@@ -382,7 +371,8 @@ class OpenAIServing: ...@@ -382,7 +371,8 @@ class OpenAIServing:
self, self,
request_id: str, request_id: str,
inputs: Union[str, List[int], TextTokensPrompt], inputs: Union[str, List[int], TextTokensPrompt],
params: Optional[Union[SamplingParams, PoolingParams]], params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None: ) -> None:
......
...@@ -87,6 +87,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -87,6 +87,7 @@ class OpenAIServingTokenization(OpenAIServing):
messages=request.messages, messages=request.messages,
chat_template=self.chat_template, chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
) )
else: else:
prompt = apply_hf_chat_template( prompt = apply_hf_chat_template(
...@@ -94,6 +95,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -94,6 +95,7 @@ class OpenAIServingTokenization(OpenAIServing):
conversation=conversation, conversation=conversation,
chat_template=self.chat_template, chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
) )
else: else:
prompt = request.prompt prompt = request.prompt
......
from .abstract_tool_parser import ToolParser from .abstract_tool_parser import ToolParser, ToolParserManager
from .hermes_tool_parser import Hermes2ProToolParser from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser from .mistral_tool_parser import MistralToolParser
__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"] __all__ = [
\ No newline at end of file "ToolParser", "ToolParserManager", "Hermes2ProToolParser",
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
]
from typing import Dict, List, Sequence, Union import importlib
import importlib.util
import os
from functools import cached_property
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
from vllm.entrypoints.openai.protocol import (DeltaMessage, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation) ExtractedToolCallInformation)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import is_list_of
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -24,8 +30,22 @@ class ToolParser: ...@@ -24,8 +30,22 @@ class ToolParser:
self.model_tokenizer = tokenizer self.model_tokenizer = tokenizer
def extract_tool_calls(self, @cached_property
model_output: str) -> ExtractedToolCallInformation: def vocab(self) -> Dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
"""
Static method that used to adjust the request parameters.
"""
return request
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
""" """
Static method that should be implemented for extracting tool calls from Static method that should be implemented for extracting tool calls from
a complete model-generated string. a complete model-generated string.
...@@ -44,6 +64,7 @@ class ToolParser: ...@@ -44,6 +64,7 @@ class ToolParser:
previous_token_ids: Sequence[int], previous_token_ids: Sequence[int],
current_token_ids: Sequence[int], current_token_ids: Sequence[int],
delta_token_ids: Sequence[int], delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]: ) -> Union[DeltaMessage, None]:
""" """
Instance method that should be implemented for extracting tool calls Instance method that should be implemented for extracting tool calls
...@@ -55,3 +76,86 @@ class ToolParser: ...@@ -55,3 +76,86 @@ class ToolParser:
raise NotImplementedError( raise NotImplementedError(
"AbstractToolParser.extract_tool_calls_streaming has not been " "AbstractToolParser.extract_tool_calls_streaming has not been "
"implemented!") "implemented!")
class ToolParserManager:
tool_parsers: Dict[str, Type] = {}
@classmethod
def get_tool_parser(cls, name) -> Type:
"""
Get tool parser by name which is registered by `register_module`.
Raise a KeyError exception if the name is not registered.
"""
if name in cls.tool_parsers:
return cls.tool_parsers[name]
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
@classmethod
def _register_module(cls,
module: Type,
module_name: Optional[Union[str, List[str]]] = None,
force: bool = True) -> None:
if not issubclass(module, ToolParser):
raise TypeError(
f'module must be subclass of ToolParser, but got {type(module)}'
)
if module_name is None:
module_name = module.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in cls.tool_parsers:
existed_module = cls.tool_parsers[name]
raise KeyError(f'{name} is already registered '
f'at {existed_module.__module__}')
cls.tool_parsers[name] = module
@classmethod
def register_module(
cls,
name: Optional[Union[str, List[str]]] = None,
force: bool = True,
module: Union[Type, None] = None) -> Union[type, Callable]:
"""
Register module with the given name or name list. it can be used as a
decoder(with module as None) or normal function(with module as not
None).
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# raise the error ahead of time
if not (name is None or isinstance(name, str)
or is_list_of(name, str)):
raise TypeError(
'name must be None, an instance of str, or a sequence of str, '
f'but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
cls._register_module(module=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(module):
cls._register_module(module=module, module_name=name, force=force)
return module
return _register
@classmethod
def import_tool_parser(cls, plugin_path: str) -> None:
"""
Import a user defined tool parser by the path of the tool parser define
file.
"""
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
spec = importlib.util.spec_from_file_location(module_name, plugin_path)
if spec is None or spec.loader is None:
logger.error("load %s from %s failed.", module_name, plugin_path)
return
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
...@@ -5,12 +5,13 @@ from typing import Dict, List, Sequence, Union ...@@ -5,12 +5,13 @@ from typing import Dict, List, Sequence, Union
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
ExtractedToolCallInformation, ExtractedToolCallInformation,
FunctionCall, ToolCall) FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser) ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import ( from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff) extract_intermediate_diff)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -20,6 +21,7 @@ from vllm.utils import random_uuid ...@@ -20,6 +21,7 @@ from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@ToolParserManager.register_module("hermes")
class Hermes2ProToolParser(ToolParser): class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: AnyTokenizer):
...@@ -48,17 +50,19 @@ class Hermes2ProToolParser(ToolParser): ...@@ -48,17 +50,19 @@ class Hermes2ProToolParser(ToolParser):
raise ValueError( raise ValueError(
"The model tokenizer must be passed to the ToolParser " "The model tokenizer must be passed to the ToolParser "
"constructor during construction.") "constructor during construction.")
self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token] self.tool_call_start_token)
self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
self.tool_call_end_token]
if not self.tool_call_start_token_id or not self.tool_call_end_token_id: if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
raise RuntimeError( raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end " "Hermes 2 Pro Tool parser could not locate tool call start/end "
"tokens in the tokenizer!") "tokens in the tokenizer!")
def extract_tool_calls(self, def extract_tool_calls(
model_output: str) -> ExtractedToolCallInformation: self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing # sanity check; avoid unnecessary processing
if self.tool_call_start_token not in model_output: if self.tool_call_start_token not in model_output:
...@@ -99,9 +103,9 @@ class Hermes2ProToolParser(ToolParser): ...@@ -99,9 +103,9 @@ class Hermes2ProToolParser(ToolParser):
tool_calls=tool_calls, tool_calls=tool_calls,
content=content if content else None) content=content if content else None)
except Exception as e: except Exception:
logger.error("Error in extracting tool call from response %s", logger.exception(
e) "Error in extracting tool call from response.")
return ExtractedToolCallInformation(tools_called=False, return ExtractedToolCallInformation(tools_called=False,
tool_calls=[], tool_calls=[],
content=model_output) content=model_output)
...@@ -114,6 +118,7 @@ class Hermes2ProToolParser(ToolParser): ...@@ -114,6 +118,7 @@ class Hermes2ProToolParser(ToolParser):
previous_token_ids: Sequence[int], previous_token_ids: Sequence[int],
current_token_ids: Sequence[int], current_token_ids: Sequence[int],
delta_token_ids: Sequence[int], delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]: ) -> Union[DeltaMessage, None]:
logger.debug("delta_text: %s", delta_text) logger.debug("delta_text: %s", delta_text)
...@@ -328,6 +333,6 @@ class Hermes2ProToolParser(ToolParser): ...@@ -328,6 +333,6 @@ class Hermes2ProToolParser(ToolParser):
return delta return delta
except Exception as e: except Exception:
logger.error("Error trying to handle streaming tool call: %s", e) logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID. return None # do not stream a delta. skip this token ID.
import json
from typing import Dict, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module(["internlm"])
class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
self.position = 0
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != 'none':
# do not skip special tokens because internlm use the special
# tokens to indicated the start and end of the tool calls
# information.
request.skip_special_tokens = False
return request
def get_argments(self, obj):
if "parameters" in obj:
return obj.get("parameters")
elif "arguments" in obj:
return obj.get("arguments")
return None
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if '<|action_start|>' not in current_text:
self.position = len(current_text)
return DeltaMessage(content=delta_text)
# if the tool call is sended, return a empty delta message
# to make sure the finish_reason will be send correctly.
if self.current_tool_id > 0:
return DeltaMessage(content='')
last_pos = self.position
if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
return None
new_delta = current_text[last_pos:]
text, action = new_delta.split('<|action_start|><|plugin|>')
if len(text) > 0:
self.position = self.position + len(text)
return DeltaMessage(content=text)
action = action.strip()
action = action.split('<|action_end|>'.strip())[0]
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
parsable_arr = action
# tool calls are generated in an object in inernlm2
# it's not support parallel tool calls
try:
tool_call_arr: Dict = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = tool_call_arr.get("name")
if function_name:
self.current_tool_id = self.current_tool_id + 1
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
self.streamed_args_for_tool.append("")
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.get_argments(
self.prev_tool_call_arr[self.current_tool_id])
cur_arguments = self.get_argments(tool_call_arr)
# not arguments generated
if not cur_arguments and not prev_arguments:
delta = None
# will never happen
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset "
"mid-arguments")
delta = None
# first time to get parameters
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
arguments_delta = cur_arguments_json[:cur_arguments_json.
index(delta_text) +
len(delta_text)]
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += arguments_delta
# both prev and cur parameters, send the increase parameters
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
self.prev_tool_call_arr = [tool_call_arr]
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
text = model_output
tools = request.tools
if '<|action_start|><|plugin|>' in text:
text, action = text.split('<|action_start|><|plugin|>')
action = action.split('<|action_end|>'.strip())[0]
action = action[action.find('{'):]
action_dict = json.loads(action)
name, parameters = action_dict['name'], json.dumps(
action_dict.get('parameters', action_dict.get('arguments',
{})))
if not tools or name not in [t.function.name for t in tools]:
ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=text)
tool_calls = [
ToolCall(
function=FunctionCall(name=name, arguments=parameters))
]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=text if len(text) > 0 else None)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=text)
import json
import re
from json import JSONDecodeError, JSONDecoder
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__)
# partial_json_parser doesn't support extra data and
# JSONDecorder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str, flags):
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e:
if "Extra data" in e.msg:
dec = JSONDecoder()
return dec.raw_decode(input_str)
else:
raise
def is_complete_json(input_str):
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
@ToolParserManager.register_module("llama3_json")
class Llama3JsonToolParser(ToolParser):
"""
Tool call parser for Llama 3.1 models intended for use with the
examples/tool_chat_template_llama.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "<|python_tag|>"
self.bot_token_id = tokenizer.encode(self.bot_token,
add_special_tokens=False)[0]
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
# case -- if a tool call token is not present, return a text response
if not (model_output.startswith(self.bot_token)
or model_output.startswith('{')):
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
# load the JSON, and then use it to build the Function and
# Tool Call
dec = JSONDecoder()
function_call_arr = []
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
start_idx = len(self.bot_token) if model_output.startswith(
self.bot_token) else 0
while start_idx < len(model_output):
(obj, end_idx) = dec.raw_decode(model_output[start_idx:])
start_idx += end_idx + len('; ')
function_call_arr.append(obj)
tool_calls: List[ToolCall] = [
ToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"] \
if "arguments" in raw_function_call \
else raw_function_call["parameters"])))
for raw_function_call in function_call_arr
]
# get any content before the tool call
ret = ExtractedToolCallInformation(tools_called=True,
tool_calls=tool_calls,
content=None)
return ret
except Exception:
logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if not (current_text.startswith(self.bot_token)
or current_text.startswith('{')):
return DeltaMessage(content=delta_text)
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
tool_call_arr = []
is_complete = []
try:
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
start_idx = len(self.bot_token) if current_text.startswith(
self.bot_token) else 0
while start_idx < len(current_text):
(obj,
end_idx) = partial_json_loads(current_text[start_idx:],
flags)
is_complete.append(
is_complete_json(current_text[start_idx:start_idx +
end_idx]))
start_idx += end_idx + len('; ')
# depending on the prompt Llama can use
# either arguments or parameters
if "parameters" in obj:
assert "arguments" not in obj, \
"model generated both parameters and arguments"
obj["arguments"] = obj["parameters"]
tool_call_arr.append(obj)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# select as the current tool call the one we're on the state at
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (len(tool_call_arr) > 0
and len(tool_call_arr) > self.current_tool_id + 1):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")
delta = None
if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(
prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
if argument_diff is not None:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None
import json import json
import re import re
from random import choices
from string import ascii_letters, digits
from typing import Dict, List, Sequence, Union from typing import Dict, List, Sequence, Union
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from pydantic import Field
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
ExtractedToolCallInformation, ExtractedToolCallInformation,
FunctionCall, ToolCall) FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser) ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import ( from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff) extract_intermediate_diff)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -19,7 +23,21 @@ from vllm.utils import random_uuid ...@@ -19,7 +23,21 @@ from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
class MistralToolCall(ToolCall):
id: str = Field(
default_factory=lambda: MistralToolCall.generate_random_id())
@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))
@ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser): class MistralToolParser(ToolParser):
""" """
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
...@@ -31,9 +49,7 @@ class MistralToolParser(ToolParser): ...@@ -31,9 +49,7 @@ class MistralToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer) super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer): if not isinstance(self.model_tokenizer, MistralTokenizer):
self.model_tokenizer = self.model_tokenizer.tokenizer
else:
logger.info("Non-Mistral tokenizer detected when using a Mistral " logger.info("Non-Mistral tokenizer detected when using a Mistral "
"model...") "model...")
...@@ -45,11 +61,18 @@ class MistralToolParser(ToolParser): ...@@ -45,11 +61,18 @@ class MistralToolParser(ToolParser):
self.streamed_args_for_tool: List[str] = [ self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list ] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]" self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.model_tokenizer.vocab[self.bot_token] self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
if not self.bot_token_id:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!")
def extract_tool_calls(self, def extract_tool_calls(
model_output: str) -> ExtractedToolCallInformation: self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
""" """
Extract the tool calls from a complete model response. Requires Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing, find-and-replacing single quotes with double quotes for JSON parsing,
...@@ -71,8 +94,8 @@ class MistralToolParser(ToolParser): ...@@ -71,8 +94,8 @@ class MistralToolParser(ToolParser):
# load the JSON, and then use it to build the Function and # load the JSON, and then use it to build the Function and
# Tool Call # Tool Call
function_call_arr = json.loads(raw_tool_call) function_call_arr = json.loads(raw_tool_call)
tool_calls: List[ToolCall] = [ tool_calls: List[MistralToolCall] = [
ToolCall( MistralToolCall(
type="function", type="function",
function=FunctionCall( function=FunctionCall(
name=raw_function_call["name"], name=raw_function_call["name"],
...@@ -88,8 +111,8 @@ class MistralToolParser(ToolParser): ...@@ -88,8 +111,8 @@ class MistralToolParser(ToolParser):
tool_calls=tool_calls, tool_calls=tool_calls,
content=content if len(content) > 0 else None) content=content if len(content) > 0 else None)
except Exception as e: except Exception:
logger.error("Error in extracting tool call from response: %s", e) logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON # return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False, return ExtractedToolCallInformation(tools_called=False,
tool_calls=[], tool_calls=[],
...@@ -103,6 +126,7 @@ class MistralToolParser(ToolParser): ...@@ -103,6 +126,7 @@ class MistralToolParser(ToolParser):
previous_token_ids: Sequence[int], previous_token_ids: Sequence[int],
current_token_ids: Sequence[int], current_token_ids: Sequence[int],
delta_token_ids: Sequence[int], delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]: ) -> Union[DeltaMessage, None]:
# if the tool call token is not in the tokens generated so far, append # if the tool call token is not in the tokens generated so far, append
...@@ -274,8 +298,8 @@ class MistralToolParser(ToolParser): ...@@ -274,8 +298,8 @@ class MistralToolParser(ToolParser):
self.prev_tool_call_arr = tool_call_arr self.prev_tool_call_arr = tool_call_arr
return delta return delta
except Exception as e: except Exception:
logger.error("Error trying to handle streaming tool call: %s", e) logger.exception("Error trying to handle streaming tool call.")
logger.debug( logger.debug(
"Skipping chunk as a result of tool streaming extraction " "Skipping chunk as a result of tool streaming extraction "
"error") "error")
......
...@@ -38,6 +38,7 @@ if TYPE_CHECKING: ...@@ -38,6 +38,7 @@ if TYPE_CHECKING:
VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_OPENVINO_DEVICE: str = "CPU"
VLLM_OPENVINO_KVCACHE_SPACE: int = 0 VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
...@@ -65,7 +66,9 @@ if TYPE_CHECKING: ...@@ -65,7 +66,9 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False VLLM_SKIP_P2P_CHECK: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_DISABLED_KERNELS: List[str] = []
def get_default_cache_root(): def get_default_cache_root():
...@@ -214,27 +217,12 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -214,27 +217,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")), ("true", "1")),
# If set, allowing the use of deprecated beam search implementation
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH":
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1",
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
lambda:
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")),
# Internal flag to control whether we use custom op,
# or use the native pytorch implementation
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),
# Internal flag to enable Dynamo fullgraph capture # Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool( lambda: bool(
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
"VLLM_TORCH_COMPILE_LEVEL":
lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
...@@ -319,6 +307,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -319,6 +307,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_CPU_OMP_THREADS_BIND": "VLLM_CPU_OMP_THREADS_BIND":
lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"), lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
# OpenVINO device selection
# default is CPU
"VLLM_OPENVINO_DEVICE":
lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(),
# OpenVINO key-value cache space # OpenVINO key-value cache space
# default is 4GB # default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE": "VLLM_OPENVINO_KVCACHE_SPACE":
...@@ -413,6 +406,8 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -413,6 +406,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: lambda:
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
("1", "true")), ("1", "true")),
"VLLM_TEST_FORCE_LOAD_FORMAT":
lambda: os.getenv("VLLM_TEST_FORCE_LOAD_FORMAT", "dummy"),
# Time in ms for the zmq client to wait for a response from the backend # Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations # server for simple data operations
...@@ -441,6 +436,21 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -441,6 +436,21 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: lambda:
(os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
("1", "true")), ("1", "true")),
# By default, vLLM will check the peer-to-peer capability itself,
# in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa
# If this env var is set to 1, vLLM will skip the peer-to-peer check,
# and trust the driver's peer-to-peer capability report.
"VLLM_SKIP_P2P_CHECK":
lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1",
# List of quantization kernels that should be disabled, used for testing
# and performance comparisons. Currently only affects MPLinearKernel
# selection
# (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
"VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -28,6 +28,8 @@ class CPUExecutor(ExecutorBase): ...@@ -28,6 +28,8 @@ class CPUExecutor(ExecutorBase):
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu" assert self.device_config.device_type == "cpu"
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
assert self.lora_config is None, "cpu backend doesn't support LoRA" assert self.lora_config is None, "cpu backend doesn't support LoRA"
# #
...@@ -324,6 +326,8 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: ...@@ -324,6 +326,8 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype == torch.float16: if config.dtype == torch.float16:
logger.warning("float16 is not supported on CPU, casting to bfloat16.") logger.warning("float16 is not supported on CPU, casting to bfloat16.")
config.dtype = torch.bfloat16 config.dtype = torch.bfloat16
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if not config.enforce_eager: if not config.enforce_eager:
logger.warning( logger.warning(
"CUDA graph is not supported on CPU, fallback to the eager " "CUDA graph is not supported on CPU, fallback to the eager "
...@@ -334,6 +338,8 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: ...@@ -334,6 +338,8 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
def _verify_and_get_scheduler_config( def _verify_and_get_scheduler_config(
config: SchedulerConfig) -> SchedulerConfig: config: SchedulerConfig) -> SchedulerConfig:
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if config.chunked_prefill_enabled: if config.chunked_prefill_enabled:
logger.warning("Chunked prefill is not supported on CPU, disable it.") logger.warning("Chunked prefill is not supported on CPU, disable it.")
config.chunked_prefill_enabled = False config.chunked_prefill_enabled = False
...@@ -342,6 +348,8 @@ def _verify_and_get_scheduler_config( ...@@ -342,6 +348,8 @@ def _verify_and_get_scheduler_config(
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if config.enable_prefix_caching: if config.enable_prefix_caching:
logger.warning("Prefix caching is not supported on CPU, disable it.") logger.warning("Prefix caching is not supported on CPU, disable it.")
config.enable_prefix_caching = False config.enable_prefix_caching = False
......
...@@ -56,6 +56,10 @@ class DistributedGPUExecutor(GPUExecutor): ...@@ -56,6 +56,10 @@ class DistributedGPUExecutor(GPUExecutor):
# have GPUs. # have GPUs.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks) num_cpu_blocks)
max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
self.model_config.max_model_len)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
self.model_config.max_model_len, max_concurrency)
self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
......
...@@ -121,6 +121,10 @@ class GPUExecutor(ExecutorBase): ...@@ -121,6 +121,10 @@ class GPUExecutor(ExecutorBase):
# remains to abstract away the device for non-GPU configurations. # remains to abstract away the device for non-GPU configurations.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks) num_cpu_blocks)
max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
self.model_config.max_model_len)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
self.model_config.max_model_len, max_concurrency)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
......
...@@ -15,8 +15,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput ...@@ -15,8 +15,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils import maybe_set_triton_cache_manager from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
get_distributed_init_method, get_open_port, cuda_is_initialized, get_distributed_init_method,
get_vllm_instance_id, make_async, get_open_port, get_vllm_instance_id, make_async,
update_environment_variables) update_environment_variables)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -122,6 +122,13 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -122,6 +122,13 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
}) })
if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
cuda_device_count = cuda_device_count_stateless() cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case. # Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, ( assert tensor_parallel_size <= cuda_device_count, (
......
...@@ -3,7 +3,6 @@ import multiprocessing ...@@ -3,7 +3,6 @@ import multiprocessing
import os import os
import sys import sys
import threading import threading
import traceback
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import Queue from multiprocessing import Queue
...@@ -27,9 +26,6 @@ RESET = '\033[0;0m' ...@@ -27,9 +26,6 @@ RESET = '\033[0;0m'
JOIN_TIMEOUT_S = 2 JOIN_TIMEOUT_S = 2
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
mp = multiprocessing.get_context(mp_method)
@dataclass @dataclass
class Result(Generic[T]): class Result(Generic[T]):
...@@ -77,7 +73,7 @@ class ResultHandler(threading.Thread): ...@@ -77,7 +73,7 @@ class ResultHandler(threading.Thread):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(daemon=True) super().__init__(daemon=True)
self.result_queue = mp.Queue() self.result_queue = get_mp_context().Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
def run(self): def run(self):
...@@ -147,10 +143,11 @@ class ProcessWorkerWrapper: ...@@ -147,10 +143,11 @@ class ProcessWorkerWrapper:
def __init__(self, result_handler: ResultHandler, def __init__(self, result_handler: ResultHandler,
worker_factory: Callable[[], Any]) -> None: worker_factory: Callable[[], Any]) -> None:
self._task_queue = mp.Queue() self.mp = get_mp_context()
self._task_queue = self.mp.Queue()
self.result_queue = result_handler.result_queue self.result_queue = result_handler.result_queue
self.tasks = result_handler.tasks self.tasks = result_handler.tasks
self.process: BaseProcess = mp.Process( # type: ignore[attr-defined] self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined]
target=_run_worker_process, target=_run_worker_process,
name="VllmWorkerProcess", name="VllmWorkerProcess",
kwargs=dict( kwargs=dict(
...@@ -204,7 +201,7 @@ def _run_worker_process( ...@@ -204,7 +201,7 @@ def _run_worker_process(
"""Worker process event loop""" """Worker process event loop"""
# Add process-specific prefix to stdout and stderr # Add process-specific prefix to stdout and stderr
process_name = mp.current_process().name process_name = get_mp_context().current_process().name
pid = os.getpid() pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid) _add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid) _add_prefix(sys.stderr, process_name, pid)
...@@ -229,10 +226,9 @@ def _run_worker_process( ...@@ -229,10 +226,9 @@ def _run_worker_process(
except KeyboardInterrupt: except KeyboardInterrupt:
break break
except BaseException as e: except BaseException as e:
tb = traceback.format_exc() logger.exception(
logger.error( "Exception in worker %s while processing method %s.",
"Exception in worker %s while processing method %s: %s, %s", process_name, method)
process_name, method, e, tb)
exception = e exception = e
result_queue.put( result_queue.put(
Result(task_id=task_id, value=output, exception=exception)) Result(task_id=task_id, value=output, exception=exception))
...@@ -268,4 +264,9 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: ...@@ -268,4 +264,9 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file.start_new_line = False # type: ignore[attr-defined] file.start_new_line = False # type: ignore[attr-defined]
file.start_new_line = True # type: ignore[attr-defined] file.start_new_line = True # type: ignore[attr-defined]
file.write = write_with_prefix # type: ignore[method-assign] file.write = write_with_prefix # type: ignore[method-assign]
\ No newline at end of file
def get_mp_context():
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)
...@@ -17,6 +17,14 @@ from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, ...@@ -17,6 +17,14 @@ from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
logger = init_logger(__name__) logger = init_logger(__name__)
def is_openvino_cpu() -> bool:
return "CPU" in envs.VLLM_OPENVINO_DEVICE
def is_openvino_gpu() -> bool:
return "GPU" in envs.VLLM_OPENVINO_DEVICE
class OpenVINOExecutor(ExecutorBase): class OpenVINOExecutor(ExecutorBase):
uses_ray: bool = False uses_ray: bool = False
...@@ -24,8 +32,13 @@ class OpenVINOExecutor(ExecutorBase): ...@@ -24,8 +32,13 @@ class OpenVINOExecutor(ExecutorBase):
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert self.device_config.device_type == "openvino" assert self.device_config.device_type == "openvino"
assert self.lora_config is None, "OpenVINO backend doesn't support LoRA" assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
assert is_openvino_cpu() or is_openvino_gpu(), \
"OpenVINO backend supports only CPU and GPU devices"
self.ov_core = ov.Core()
self.model_config = _verify_and_get_model_config(self.model_config) self.model_config = _verify_and_get_model_config(self.model_config)
self.cache_config = _verify_and_get_cache_config(self.cache_config) self.cache_config = _verify_and_get_cache_config(
self.ov_core, self.cache_config)
# Instantiate the worker and load the model to CPU. # Instantiate the worker and load the model to CPU.
self._init_worker() self._init_worker()
...@@ -40,6 +53,7 @@ class OpenVINOExecutor(ExecutorBase): ...@@ -40,6 +53,7 @@ class OpenVINOExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
self.driver_worker = OpenVINOWorker( self.driver_worker = OpenVINOWorker(
ov_core=self.ov_core,
model_config=self.model_config, model_config=self.model_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
...@@ -68,10 +82,13 @@ class OpenVINOExecutor(ExecutorBase): ...@@ -68,10 +82,13 @@ class OpenVINOExecutor(ExecutorBase):
# NOTE: We log here to avoid multiple logs when number of workers is # NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors # greater than one. We could log in the engine, but not all executors
# have GPUs. # have GPUs.
# NOTE: `cpu block` for OpenVINO backend is located on CPU memory but is # NOTE: In case of a CPU device, `cpu block` for OpenVINO backend
# referred as `gpu block`. Because we want to reuse the existing block # is located on CPU memory but is referred as `gpu block`.
# management procedure. # Because we want to reuse the existing block management procedure.
logger.info("# CPU blocks: %d", num_gpu_blocks) device_blocks = num_gpu_blocks
swap_blocks = num_cpu_blocks
logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d",
envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model( def execute_model(
...@@ -143,29 +160,45 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: ...@@ -143,29 +160,45 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
return config return config
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: def _verify_and_get_cache_config(ov_core: ov.Core,
config: CacheConfig) -> CacheConfig:
if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
logger.info("KV cache type is overried to u8 via " if not is_openvino_cpu():
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
config.cache_dtype = ov.Type.u8 "ignored for GPU, f16 data type will be used.")
config.cache_dtype = ov.Type.f16
else:
logger.info("KV cache type is overridden to u8 via "
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
config.cache_dtype = ov.Type.u8
else: else:
core = ov.Core() if is_openvino_cpu():
inference_precision = core.get_property("CPU", ov_device = envs.VLLM_OPENVINO_DEVICE
hints.inference_precision) inference_precision = ov_core.get_property(
if inference_precision == ov.Type.bf16: ov_device, hints.inference_precision)
config.cache_dtype = ov.Type.bf16 if inference_precision == ov.Type.bf16:
config.cache_dtype = ov.Type.bf16
else:
config.cache_dtype = ov.Type.f16
else: else:
config.cache_dtype = ov.Type.f16 config.cache_dtype = ov.Type.f16
if config.block_size != 32: if is_openvino_cpu():
logger.info( if config.block_size != 32:
f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501 logger.info(
) f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
config.block_size = 32 )
config.block_size = 32
else:
if config.block_size != 16:
logger.info(
f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}" # noqa: G004, E501
)
config.block_size = 16
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
if kv_cache_space >= 0: if kv_cache_space >= 0:
if kv_cache_space == 0: if kv_cache_space == 0 and is_openvino_cpu():
config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning( logger.warning(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
......
from contextlib import contextmanager
from typing import Any
_forward_context: Any = None
def get_forward_context() -> Any:
"""Get the current forward context."""
return _forward_context
@contextmanager
def set_forward_context(context: Any):
"""A context manager that stores the current forward context,
can be attention metadata, etc."""
global _forward_context
prev_context = _forward_context
_forward_context = context
try:
yield
finally:
_forward_context = prev_context
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
TokensPrompt, build_explicit_enc_dec_prompt, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts) build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
...@@ -16,11 +17,14 @@ See also: ...@@ -16,11 +17,14 @@ See also:
__all__ = [ __all__ = [
"TextPrompt", "TextPrompt",
"TokensPrompt", "TokensPrompt",
"PromptInputs", "PromptType",
"SingletonPromptInputs", "SingletonPrompt",
"ExplicitEncoderDecoderPrompt", "ExplicitEncoderDecoderPrompt",
"LLMInputs", "TokenInputs",
"EncoderDecoderLLMInputs", "token_inputs",
"SingletonInputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"build_explicit_enc_dec_prompt", "build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list", "to_enc_dec_tuple_list",
"zip_enc_dec_prompts", "zip_enc_dec_prompts",
...@@ -28,3 +32,34 @@ __all__ = [ ...@@ -28,3 +32,34 @@ __all__ = [
"InputContext", "InputContext",
"InputRegistry", "InputRegistry",
] ]
def __getattr__(name: str):
import warnings
if name == "PromptInput":
msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return PromptType
if name == "LLMInputs":
msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return DecoderOnlyInputs
if name == "EncoderDecoderLLMInputs":
msg = (
"EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return EncoderDecoderInputs
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple, from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Union) Optional, Tuple, Union, cast)
from typing_extensions import NotRequired, TypedDict, TypeVar from typing_extensions import NotRequired, TypedDict, TypeVar
...@@ -19,6 +19,14 @@ class TextPrompt(TypedDict): ...@@ -19,6 +19,14 @@ class TextPrompt(TypedDict):
if the model supports it. if the model supports it.
""" """
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
class TokensPrompt(TypedDict): class TokensPrompt(TypedDict):
"""Schema for a tokenized prompt.""" """Schema for a tokenized prompt."""
...@@ -32,10 +40,18 @@ class TokensPrompt(TypedDict): ...@@ -32,10 +40,18 @@ class TokensPrompt(TypedDict):
if the model supports it. if the model supports it.
""" """
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
""" """
Set of possible schemas for a single LLM input: Set of possible schemas for a single prompt:
- A text prompt (:class:`str` or :class:`TextPrompt`) - A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`) - A tokenized prompt (:class:`TokensPrompt`)
...@@ -46,7 +62,7 @@ which may be utilized for encoder/decoder models when ...@@ -46,7 +62,7 @@ which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPromptInputs` may be employed A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or where the decoder-prompt is not specified explicitly, or
...@@ -55,41 +71,44 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt` ...@@ -55,41 +71,44 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
""" """
_T1_co = TypeVar("_T1_co", _T1_co = TypeVar("_T1_co",
bound=SingletonPromptInputs, bound=SingletonPrompt,
default=SingletonPromptInputs, default=SingletonPrompt,
covariant=True) covariant=True)
_T2_co = TypeVar("_T2_co", _T2_co = TypeVar("_T2_co",
bound=SingletonPromptInputs, bound=SingletonPrompt,
default=SingletonPromptInputs, default=SingletonPrompt,
covariant=True) covariant=True)
# TODO: Make fields ReadOnly once mypy supports it # TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
"""Represents an encoder/decoder model input prompt, """
comprising an explicit encoder prompt and a Represents an encoder/decoder model input prompt,
decoder prompt. comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively, The encoder and decoder prompts, respectively, may be formatted
may formatted according to any of the according to any of the :class:`SingletonPrompt` schemas,
:class:`SingletonPromptInputs` schemas, and are not and are not required to have the same schema.
required to have the same schema.
Only the encoder prompt may have multi-modal data. Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder.
Note that an :class:`ExplicitEncoderDecoderPrompt` may not Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model, be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt` and that the :code:`encoder_prompt` and :code:`decoder_prompt`
fields of this data structure themselves must be fields of this data structure themselves must be
:class:`SingletonPromptInputs` instances. :class:`SingletonPrompt` instances.
""" """
encoder_prompt: _T1_co encoder_prompt: _T1_co
decoder_prompt: Optional[_T2_co] decoder_prompt: Optional[_T2_co]
mm_processor_kwargs: NotRequired[Dict[str, Any]]
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
""" """
Set of possible schemas for an LLM input, including Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types: both decoder-only and encoder/decoder input types:
...@@ -101,13 +120,8 @@ both decoder-only and encoder/decoder input types: ...@@ -101,13 +120,8 @@ both decoder-only and encoder/decoder input types:
""" """
class LLMInputs(TypedDict): class TokenInputs(TypedDict):
""" """Represents token-based inputs."""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
prompt_token_ids: List[int] prompt_token_ids: List[int]
"""The token IDs of the prompt.""" """The token IDs of the prompt."""
...@@ -122,8 +136,49 @@ class LLMInputs(TypedDict): ...@@ -122,8 +136,49 @@ class LLMInputs(TypedDict):
if the model supports it. if the model supports it.
""" """
mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
def token_inputs(
prompt_token_ids: List[int],
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
"""Construct :class:`TokenInputs` from optional values."""
inputs = TokenInputs(prompt_token_ids=prompt_token_ids)
if prompt is not None:
inputs["prompt"] = prompt
if multi_modal_data is not None:
inputs["multi_modal_data"] = multi_modal_data
if mm_processor_kwargs is not None:
inputs["mm_processor_kwargs"] = mm_processor_kwargs
return inputs
class EncoderDecoderLLMInputs(LLMInputs): SingletonInputs = TokenInputs
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""
DecoderOnlyInputs = TokenInputs
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
class EncoderDecoderInputs(TokenInputs):
""" """
The inputs in :class:`~vllm.LLMEngine` before they are The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor. passed to the model executor.
...@@ -146,33 +201,51 @@ class EncoderDecoderLLMInputs(LLMInputs): ...@@ -146,33 +201,51 @@ class EncoderDecoderLLMInputs(LLMInputs):
""" """
_T1 = TypeVar("_T1", _T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
bound=SingletonPromptInputs, _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
default=SingletonPromptInputs)
_T2 = TypeVar("_T2",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
def build_explicit_enc_dec_prompt( def build_explicit_enc_dec_prompt(
encoder_prompt: _T1, encoder_prompt: _T1,
decoder_prompt: Optional[_T2], decoder_prompt: Optional[_T2],
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: ) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, if mm_processor_kwargs is None:
decoder_prompt=decoder_prompt) mm_processor_kwargs = {}
return ExplicitEncoderDecoderPrompt(
encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt,
mm_processor_kwargs=mm_processor_kwargs)
def zip_enc_dec_prompts( def zip_enc_dec_prompts(
enc_prompts: Iterable[_T1], enc_prompts: Iterable[_T1],
dec_prompts: Iterable[Optional[_T2]], dec_prompts: Iterable[Optional[_T2]],
mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]],
Dict[str, Any]]] = None,
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: ) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
""" """
Zip encoder and decoder prompts together into a list of Zip encoder and decoder prompts together into a list of
:class:`ExplicitEncoderDecoderPrompt` instances. :class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs
may also be provided; if a dict is passed, the same dictionary will be
used for every encoder/decoder prompt. If an iterable is provided, it will
be zipped with the encoder/decoder prompts.
""" """
if mm_processor_kwargs is None:
mm_processor_kwargs = cast(Dict[str, Any], {})
if isinstance(mm_processor_kwargs, dict):
return [
build_explicit_enc_dec_prompt(
encoder_prompt, decoder_prompt,
cast(Dict[str, Any], mm_processor_kwargs))
for (encoder_prompt,
decoder_prompt) in zip(enc_prompts, dec_prompts)
]
return [ return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt) build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts) mm_proc_kwargs)
for (encoder_prompt, decoder_prompt, mm_proc_kwargs
) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
] ]
...@@ -182,3 +255,34 @@ def to_enc_dec_tuple_list( ...@@ -182,3 +255,34 @@ def to_enc_dec_tuple_list(
return [(enc_dec_prompt["encoder_prompt"], return [(enc_dec_prompt["encoder_prompt"],
enc_dec_prompt["decoder_prompt"]) enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts] for enc_dec_prompt in enc_dec_prompts]
def __getattr__(name: str):
import warnings
if name == "PromptInput":
msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return PromptType
if name == "LLMInputs":
msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return DecoderOnlyInputs
if name == "EncoderDecoderLLMInputs":
msg = (
"EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return EncoderDecoderInputs
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
from typing import List, Literal, Sequence, TypedDict, Union, overload from typing import List, Literal, Sequence, TypedDict, Union, cast, overload
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
TokensPrompt) TextPrompt, TokensPrompt)
class ParsedText(TypedDict): class ParsedText(TypedDict):
...@@ -44,13 +44,16 @@ def parse_and_batch_prompt( ...@@ -44,13 +44,16 @@ def parse_and_batch_prompt(
if is_list_of(prompt, str): if is_list_of(prompt, str):
# case 2: array of strings # case 2: array of strings
prompt = cast(List[str], prompt)
return [ return [
ParsedText(content=elem, is_tokens=False) for elem in prompt ParsedText(content=elem, is_tokens=False) for elem in prompt
] ]
if is_list_of(prompt, int): if is_list_of(prompt, int):
# case 3: array of tokens # case 3: array of tokens
prompt = cast(List[int], prompt)
return [ParsedTokens(content=prompt, is_tokens=True)] return [ParsedTokens(content=prompt, is_tokens=True)]
if is_list_of(prompt, list): if is_list_of(prompt, list):
prompt = cast(List[List[int]], prompt)
if len(prompt[0]) == 0: if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt") raise ValueError("please provide at least one prompt")
...@@ -81,26 +84,26 @@ class ParsedTokensPrompt(TypedDict): ...@@ -81,26 +84,26 @@ class ParsedTokensPrompt(TypedDict):
def parse_singleton_prompt( def parse_singleton_prompt(
inputs: SingletonPromptInputs, prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
if isinstance(inputs, str): if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=inputs) return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(inputs, dict): elif isinstance(prompt, dict):
if "prompt_token_ids" in inputs: if "prompt_token_ids" in prompt:
return ParsedTokensPrompt(type="tokens", return ParsedTokensPrompt(type="tokens",
content=inputs) # type: ignore content=prompt) # type: ignore
elif "prompt" in inputs: elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=inputs) return ParsedTextPrompt(type="text", content=prompt)
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
def is_explicit_encoder_decoder_prompt( def is_explicit_encoder_decoder_prompt(
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(inputs, dict) and "encoder_prompt" in inputs return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_valid_encoder_decoder_llm_inputs( def is_encoder_decoder_inputs(
inputs: Union[LLMInputs, EncoderDecoderLLMInputs], inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
) -> TypeIs[EncoderDecoderLLMInputs]: ) -> TypeIs[EncoderDecoderInputs]:
return "encoder_prompt_token_ids" in inputs return "encoder_prompt_token_ids" in inputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment