Commit afd0da21 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.1' into v0.7.1-dev

parents 1a11f127 4f4d427a
...@@ -25,7 +25,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -25,7 +25,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T, IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest, RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse, RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest) RPCUProfileRequest)
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
...@@ -240,22 +243,34 @@ class MQLLMEngineClient(EngineClient): ...@@ -240,22 +243,34 @@ class MQLLMEngineClient(EngineClient):
queue = self.output_queues.get(request_id) queue = self.output_queues.get(request_id)
if queue is not None: if queue is not None:
queue.put_nowait(exception) queue.put_nowait(exception)
# Put each output into the appropriate queue.
elif isinstance(request_outputs, RPCAdapterLoadedResponse):
self._add_output(request_outputs)
else: else:
# Put each output into the appropriate steam.
for request_output in request_outputs: for request_output in request_outputs:
queue = self.output_queues.get( self._add_output(request_output)
request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient output handler.") logger.debug("Shutting down MQLLMEngineClient output handler.")
def _add_output(self, request_output: Union[RequestOutput,
                                            RPCAdapterLoadedResponse]):
    """Deliver one engine response to the queue registered for its request.

    The target queue is looked up in ``self.output_queues`` under
    ``request_output.request_id``. If no queue is registered under that
    id, the response is dropped without error.
    """
    pending = self.output_queues.get(request_output.request_id)
    if pending is None:
        # No queue is registered under this id; nothing to deliver.
        return
    pending.put_nowait(request_output)
async def setup(self): async def setup(self):
"""Setup the client before it starts sending server requests.""" """Setup the client before it starts sending server requests."""
# Start output_loop # Start output_loop
self.output_loop = asyncio.create_task(self.run_output_handler_loop()) if self.output_loop is None:
# only generate once to avoid multiple concurrent output_loops
# this will lead to race conditions and wrong orders of tokens
# returned by the engine
# setup will be called multiple times during the startup of
# the engine
self.output_loop = asyncio.create_task(
self.run_output_handler_loop())
with self.get_data_socket() as socket: with self.get_data_socket() as socket:
# Wait until server is ready. # Wait until server is ready.
...@@ -264,8 +279,9 @@ class MQLLMEngineClient(EngineClient): ...@@ -264,8 +279,9 @@ class MQLLMEngineClient(EngineClient):
self.tracing_flag = response.tracing_enabled self.tracing_flag = response.tracing_enabled
# Start health_loop. # Start health_loop.
self.health_loop = asyncio.create_task( if self.health_loop is None:
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
def close(self): def close(self):
"""Destroy the ZeroMQ Context.""" """Destroy the ZeroMQ Context."""
...@@ -659,3 +675,31 @@ class MQLLMEngineClient(EngineClient): ...@@ -659,3 +675,31 @@ class MQLLMEngineClient(EngineClient):
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
async def reset_prefix_cache(self) -> None:
    """Send a one-way RPC on the input socket asking for a prefix cache reset."""
    reset_command = RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE
    await self._send_one_way_rpc_request(request=reset_command,
                                         socket=self.input_socket)
async def add_lora(self, lora_request: LoRARequest) -> None:
    """Load a new LoRA adapter into the engine for future requests.

    The request is sent over the same input socket as generation requests.
    The reply is delivered through a per-request queue registered in
    ``self.output_queues``.

    Args:
        lora_request: The adapter to load. It is wrapped in an
            ``RPCLoadAdapterRequest`` and pickled onto the input socket.

    Raises:
        BaseException: Re-raises the exception delivered on the reply
            queue in place of a success response.
    """
    # Uses the same I/O as generate requests
    request = RPCLoadAdapterRequest(lora_request)

    # Register the reply queue under the request id before sending.
    reply_queue: asyncio.Queue[Union[RPCAdapterLoadedResponse,
                                     BaseException]] = asyncio.Queue()
    self.output_queues[request.request_id] = reply_queue

    try:
        # Send the request
        request_bytes = pickle.dumps(request)
        await self.input_socket.send_multipart((request_bytes, ), copy=False)

        # Wait for the response
        request_output = await reply_queue.get()
    finally:
        # Always unregister the queue. Without this, a failed send or a
        # cancelled wait would leave a stale entry in output_queues.
        self.output_queues.pop(request.request_id, None)

    # Raise on error, otherwise happily return None
    if isinstance(request_output, BaseException):
        raise request_output
...@@ -14,11 +14,13 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -14,11 +14,13 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest, RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse, RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest) RPCUProfileRequest)
# yapf: enable # yapf: enable
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -234,6 +236,10 @@ class MQLLMEngine: ...@@ -234,6 +236,10 @@ class MQLLMEngine:
self.start_profile() self.start_profile()
else: else:
self.stop_profile() self.stop_profile()
elif isinstance(request, RPCLoadAdapterRequest):
self._handle_load_adapter_request(request)
elif isinstance(request, RPCResetPrefixCacheRequest):
self.reset_prefix_cache()
else: else:
raise ValueError("Unknown RPCRequest Type: " raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}") f"{type(request)}")
...@@ -284,6 +290,20 @@ class MQLLMEngine: ...@@ -284,6 +290,20 @@ class MQLLMEngine:
if self.log_requests: if self.log_requests:
logger.info("Aborted request %s.", request.request_id) logger.info("Aborted request %s.", request.request_id)
def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
    """Load the requested LoRA adapter and report the outcome.

    A failed load is reported through ``_send_outputs`` as an ``RPCError``
    carrying the exception, with ``is_engine_errored=False``. A successful
    load is reported as an ``RPCAdapterLoadedResponse``. Both replies carry
    the id of the incoming request.
    """
    try:
        self.engine.add_lora(request.lora_request)
    except BaseException as load_error:
        # Send back an error if the adapter fails to load
        failure = RPCError(request_id=request.request_id,
                           is_engine_errored=False,
                           exception=load_error)
        self._send_outputs(failure)
    else:
        # Otherwise, send back the successful load message
        loaded = RPCAdapterLoadedResponse(request_id=request.request_id)
        self._send_outputs(loaded)
def _health_check(self): def _health_check(self):
# Send unhealthy if engine has already errored # Send unhealthy if engine has already errored
if self._errored_with is not None: if self._errored_with is not None:
...@@ -296,7 +316,11 @@ class MQLLMEngine: ...@@ -296,7 +316,11 @@ class MQLLMEngine:
self._send_unhealthy(e) self._send_unhealthy(e)
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient.""" """Send outputs back to the engine client. These can be:
- Exceptions
- A list of generation outputs
- A response from loading a lora adapter
"""
if outputs: if outputs:
try: try:
from ray.exceptions import RayTaskError from ray.exceptions import RayTaskError
...@@ -335,16 +359,13 @@ class MQLLMEngine: ...@@ -335,16 +359,13 @@ class MQLLMEngine:
self._errored_with = e self._errored_with = e
def start_profile(self) -> None: def start_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor: self.engine.start_profile()
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
def stop_profile(self) -> None: def stop_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor: self.engine.stop_profile()
self.engine.model_executor.stop_profile()
else: def reset_prefix_cache(self) -> bool:
self.engine.model_executor._run_workers("stop_profile") return self.engine.reset_prefix_cache()
def signal_handler(*_) -> None: def signal_handler(*_) -> None:
......
...@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@staticmethod @staticmethod
@functools.lru_cache @functools.lru_cache
def _log_prompt_logprob_unsupported_warning_once(): def _log_prompt_logprob_unsupported_warning_once():
# Reminder: Please update docs/source/usage/compatibility_matrix.md # Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo becomes valid # If the feature combo becomes valid
logger.warning( logger.warning(
"Prompt logprob is not supported by multi step workers. " "Prompt logprob is not supported by multi step workers. "
...@@ -144,7 +144,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -144,7 +144,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def _process_decode_and_stop(self, seq: Sequence, def _process_decode_and_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None: sampling_params: SamplingParams) -> None:
new_char_count = 0 new_char_count = 0
if sampling_params.detokenize: if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace( new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params) seq, sampling_params)
......
...@@ -102,9 +102,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -102,9 +102,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
Args: Args:
seq_group: the output is associated with this :class:`SequenceGroup` seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step outputs: the :class:`SequenceGroupOutput` for a single scheduler step
""" """
assert len(outputs) == 1, ("Single step should only has 1 output.") assert len(outputs) == 1, "Single step should only have 1 output."
output = outputs[0] output = outputs[0]
assert isinstance(output, CompletionSequenceGroupOutput) assert isinstance(output, CompletionSequenceGroupOutput)
single_step_process_prompt_logprob(self, seq_group, output) single_step_process_prompt_logprob(self, seq_group, output)
......
...@@ -270,3 +270,13 @@ class EngineClient(ABC): ...@@ -270,3 +270,13 @@ class EngineClient(ABC):
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
"""Stop profiling the engine""" """Stop profiling the engine"""
... ...
@abstractmethod
async def reset_prefix_cache(self) -> None:
    """Reset the prefix cache.

    Abstract method: concrete engine clients must provide the
    implementation.
    """
    ...
@abstractmethod
async def add_lora(self, lora_request: LoRARequest) -> None:
    """Load a new LoRA adapter into the engine for future requests.

    Abstract method: concrete engine clients must provide the
    implementation.

    Args:
        lora_request: The adapter to load.
    """
    ...
...@@ -3,10 +3,10 @@ import codecs ...@@ -3,10 +3,10 @@ import codecs
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict, deque from collections import defaultdict, deque
from functools import lru_cache, partial from functools import cache, lru_cache, partial
from pathlib import Path from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) Literal, Optional, Tuple, TypeVar, Union, cast)
import jinja2.nodes import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils import transformers.utils.chat_template_utils as hf_chat_utils
...@@ -23,6 +23,8 @@ from openai.types.chat import ( ...@@ -23,6 +23,8 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam, from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam) ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
# yapf: enable # yapf: enable
# pydantic needs the TypedDict from typing_extensions # pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
...@@ -31,13 +33,8 @@ from typing_extensions import Required, TypeAlias, TypedDict ...@@ -31,13 +33,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio, from vllm.multimodal.utils import MediaConnector
async_get_and_parse_image,
async_get_and_parse_video,
get_and_parse_audio, get_and_parse_image,
get_and_parse_video)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -368,16 +365,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -368,16 +365,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._tokenizer = tokenizer self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {}) if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._items: List[_T] = [] self._items_by_modality = defaultdict[str, list[_T]](list)
@property @property
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
return self._model_config return self._model_config
@property
def allowed_local_media_path(self):
    """Return ``allowed_local_media_path`` from the tracker's model config."""
    config = self._model_config
    return config.allowed_local_media_path
@staticmethod @staticmethod
@lru_cache(maxsize=None) @cache
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
return tokenizer.decode(token_index) return tokenizer.decode(token_index)
...@@ -392,7 +392,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -392,7 +392,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type == "phi3_v": if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer # Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>" return f"<|image_{current_count}|>"
if model_type == "minicpmv": if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)" return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma", if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
"pixtral"): "pixtral"):
...@@ -403,8 +403,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -403,8 +403,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"): if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer, return self._cached_token_str(self._tokenizer,
hf_config.image_token_index) hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat", "NVLM_D", if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
"h2ovl_chat"): "NVLM_D", "h2ovl_chat"):
return "<image>" return "<image>"
if model_type == "mllama": if model_type == "mllama":
return "<|image|>" return "<|image|>"
...@@ -424,10 +424,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -424,10 +424,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type == "qwen2_audio": if model_type == "qwen2_audio":
return (f"Audio {current_count}: " return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>") f"<|audio_bos|><|AUDIO|><|audio_eos|>")
if model_type == "minicpmo":
return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video": elif modality == "video":
if model_type == "qwen2_vl": if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>" return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"):
return "(<video>./</video>)"
if model_type.startswith("llava"): if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer, return self._cached_token_str(self._tokenizer,
hf_config.video_token_index) hf_config.video_token_index)
...@@ -435,38 +439,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -435,38 +439,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else: else:
raise TypeError(f"Unknown modality: {modality}") raise TypeError(f"Unknown modality: {modality}")
@staticmethod
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
mm_lists: Mapping[str, List[object]] = defaultdict(list)
# Merge all the multi-modal items
for single_mm_data in items:
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)
# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}
def add(self, modality: ModalityStr, item: _T) -> Optional[str]: def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
""" """
Add a multi-modal item to the current prompt and returns the Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any. placeholder string to use, if any.
""" """
allowed_count = self._allowed_items.get(modality, 1) allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1 current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count: if current_count > allowed_count:
raise ValueError( raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in " f"At most {allowed_count} {modality}(s) may be provided in "
"one request.") "one request.")
self._consumed_items[modality] = current_count self._items_by_modality[modality].append(item)
self._items.append(item)
return self._placeholder_str(modality, current_count) return self._placeholder_str(modality, current_count)
...@@ -475,22 +460,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -475,22 +460,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise NotImplementedError raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]): class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]: def all_mm_data(self) -> Optional[MultiModalDataDict]:
return self._combine(self._items) if self._items else None if self._items_by_modality:
return dict(self._items_by_modality)
return None
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self) return MultiModalContentParser(self)
class AsyncMultiModalItemTracker( class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]: async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items: if self._items_by_modality:
items = await asyncio.gather(*self._items) return {
return self._combine(items) modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}
return None return None
...@@ -522,7 +511,7 @@ class BaseMultiModalContentParser(ABC): ...@@ -522,7 +511,7 @@ class BaseMultiModalContentParser(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_input_audio(self, input_audio: Dict[str, str]) -> None: def parse_input_audio(self, input_audio: InputAudio) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
...@@ -537,31 +526,31 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -537,31 +526,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url, image = self._connector.fetch_image(image_url)
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)
placeholder = self._tracker.add("image", image) placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str) -> None:
audio = get_and_parse_audio(audio_url) audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio) placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None: def parse_input_audio(self, input_audio: InputAudio) -> None:
input_audio_data = input_audio.get("data","") audio_data = input_audio.get("data", "")
input_audio_format = input_audio.get("format","") audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" audio_url = f"data:audio/{audio_format};base64,{audio_data}"
audio = get_and_parse_audio(audio_url)
placeholder = self._tracker.add("audio", audio) return self.parse_audio(audio_url)
self._add_placeholder(placeholder)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str) -> None:
video = get_and_parse_video(video_url) video = self._connector.fetch_video(video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
...@@ -573,33 +562,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -573,33 +562,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image( image_coro = self._connector.fetch_image_async(image_url)
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)
placeholder = self._tracker.add("image", image_coro) placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str) -> None:
audio_coro = async_get_and_parse_audio(audio_url) audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro) placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None: def parse_input_audio(self, input_audio: InputAudio) -> None:
input_audio_data = input_audio.get("data","") audio_data = input_audio.get("data", "")
input_audio_format = input_audio.get("format","") audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" audio_url = f"data:audio/{audio_format};base64,{audio_data}"
audio_coro = async_get_and_parse_audio(audio_url)
placeholder = self._tracker.add("audio", audio_coro) return self.parse_audio(audio_url)
self._add_placeholder(placeholder)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str) -> None:
video = async_get_and_parse_video(video_url) video = self._connector.fetch_video_async(video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
...@@ -695,10 +682,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) ...@@ -695,10 +682,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
# Define a mapping from part types to their corresponding parsing functions. # Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, MM_PARSER_MAP: Dict[
Callable[[ChatCompletionContentPartParam], str,
Union[str, Dict[str,str]]]] = { Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
"text": "text":
lambda part: _TextParser(part).get("text", ""), lambda part: _TextParser(part).get("text", ""),
"image_url": "image_url":
...@@ -715,8 +705,7 @@ MM_PARSER_MAP: Dict[str, ...@@ -715,8 +705,7 @@ MM_PARSER_MAP: Dict[str,
def _parse_chat_message_content_mm_part( def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str, part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
Union[str, Dict[str, str]]]:
""" """
Parses a given multi-modal content part based on its type. Parses a given multi-modal content part based on its type.
...@@ -783,7 +772,7 @@ def _parse_chat_message_content_parts( ...@@ -783,7 +772,7 @@ def _parse_chat_message_content_parts(
*, *,
wrap_dicts: bool, wrap_dicts: bool,
) -> List[ConversationMessage]: ) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = [] content = list[_ContentPart]()
mm_parser = mm_tracker.create_parser() mm_parser = mm_tracker.create_parser()
...@@ -814,7 +803,7 @@ def _parse_chat_message_content_part( ...@@ -814,7 +803,7 @@ def _parse_chat_message_content_part(
mm_parser: BaseMultiModalContentParser, mm_parser: BaseMultiModalContentParser,
*, *,
wrap_dicts: bool, wrap_dicts: bool,
) -> Optional[Union[str, Dict[str, str]]]: ) -> Optional[_ContentPart]:
"""Parses a single part of a conversation. If wrap_dicts is True, """Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
...@@ -823,8 +812,7 @@ def _parse_chat_message_content_part( ...@@ -823,8 +812,7 @@ def _parse_chat_message_content_part(
with multimodal placeholders. with multimodal placeholders.
""" """
if isinstance(part, str): # Handle plain text parts if isinstance(part, str): # Handle plain text parts
text = _TextParser(part) return part
return text
# Handle structured dictionary parts # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part) part_type, content = _parse_chat_message_content_mm_part(part)
...@@ -855,7 +843,7 @@ def _parse_chat_message_content_part( ...@@ -855,7 +843,7 @@ def _parse_chat_message_content_part(
return {'type': 'audio'} if wrap_dicts else None return {'type': 'audio'} if wrap_dicts else None
if part_type == "input_audio": if part_type == "input_audio":
dict_content = cast(Dict[str, str], content) dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content) mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None return {'type': 'audio'} if wrap_dicts else None
...@@ -1000,14 +988,14 @@ def apply_mistral_chat_template( ...@@ -1000,14 +988,14 @@ def apply_mistral_chat_template(
**kwargs: Any, **kwargs: Any,
) -> List[int]: ) -> List[int]:
if chat_template is not None: if chat_template is not None:
print_warning_once( logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.") "'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs: if "add_generation_prompt" in kwargs:
print_warning_once( logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, " "'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.") "so it will be ignored.")
if "continue_final_message" in kwargs: if "continue_final_message" in kwargs:
print_warning_once( logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, " "'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.") "so it will be ignored.")
......
import itertools import itertools
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Union, cast, overload) Tuple, Type, Union, cast, overload)
import cloudpickle
import torch
import torch.nn as nn
from tqdm import tqdm from tqdm import tqdm
from typing_extensions import deprecated from typing_extensions import TypeVar, deprecated
from vllm import envs from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
...@@ -21,7 +24,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, ...@@ -21,7 +24,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
parse_chat_messages, parse_chat_messages,
resolve_chat_template_content_format) resolve_chat_template_content_format)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import ( from vllm.model_executor.guided_decoding.guided_fields import (
...@@ -41,6 +44,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of ...@@ -41,6 +44,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
logger = init_logger(__name__) logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class LLM: class LLM:
"""An LLM for generating texts from given prompts and sampling parameters. """An LLM for generating texts from given prompts and sampling parameters.
...@@ -186,6 +191,13 @@ class LLM: ...@@ -186,6 +191,13 @@ class LLM:
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True
if "worker_cls" in kwargs:
worker_cls = kwargs["worker_cls"]
# if the worker_cls is not qualified string name,
# we serialize it using cloudpickle to avoid pickling issues
if isinstance(worker_cls, type):
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
if compilation_config is not None: if compilation_config is not None:
if isinstance(compilation_config, (int, dict)): if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli( compilation_config_instance = CompilationConfig.from_cli(
...@@ -225,18 +237,11 @@ class LLM: ...@@ -225,18 +237,11 @@ class LLM:
# Logic to switch between engines is done at runtime instead of import # Logic to switch between engines is done at runtime instead of import
# to avoid import order issues # to avoid import order issues
self.engine_class = self.get_engine_class() self.engine_class = self.get_engine_class()
# TODO(rob): enable mp by default (issue with fork vs spawn)
self.llm_engine = self.engine_class.from_engine_args( self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS) engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter() self.request_counter = Counter()
def __del__(self):
    """Shut down the underlying engine when this object is finalized.

    Nothing happens if ``llm_engine`` was never assigned, is falsy, or
    does not expose a ``shutdown`` attribute.
    """
    # ``__init__`` may have failed before ``llm_engine`` was set, so the
    # attribute is looked up defensively.
    engine = getattr(self, "llm_engine", None)
    if engine and hasattr(engine, "shutdown"):
        engine.shutdown()
@staticmethod @staticmethod
def get_engine_class() -> Type[LLMEngine]: def get_engine_class() -> Type[LLMEngine]:
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
...@@ -462,9 +467,47 @@ class LLM: ...@@ -462,9 +467,47 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput) return self.engine_class.validate_outputs(outputs, RequestOutput)
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: Tuple = (),
                   kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
    """
    Run one RPC call on every worker and gather the results.

    Args:
        method: Either the name of a worker method, or a callable that is
            serialized and shipped to every worker. A callable receives
            the worker object as an extra leading ``self`` argument, ahead
            of whatever is supplied through `args` and `kwargs`.
        timeout: Upper bound, in seconds, on how long to wait for the
            call. A :exc:`TimeoutError` is raised when it is exceeded;
            `None` waits without limit.
        args: Positional arguments forwarded to the worker method.
        kwargs: Keyword arguments forwarded to the worker method.

    Returns:
        A list with one result per worker.

    Note:
        This API is intended for control messages only; set up a separate
        data-plane channel when actual data has to be moved.
    """
    # Pure delegation: the engine's model executor performs the fan-out.
    return self.llm_engine.model_executor.collective_rpc(
        method, timeout, args, kwargs)
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
    """
    Call `func` on the model held by each worker.

    The work is delegated to the engine's model executor, and the
    per-worker results are returned as a list.
    """
    return self.llm_engine.model_executor.apply_model(func)
def beam_search( def beam_search(
self, self,
prompts: List[Union[str, List[int]]], prompts: List[Union[TokensPrompt, TextPrompt]],
params: BeamSearchParams, params: BeamSearchParams,
) -> List[BeamSearchOutput]: ) -> List[BeamSearchOutput]:
""" """
...@@ -500,8 +543,10 @@ class LLM: ...@@ -500,8 +543,10 @@ class LLM:
instances: List[BeamSearchInstance] = [] instances: List[BeamSearchInstance] = []
for prompt in prompts: for prompt in prompts:
prompt_tokens = prompt if isinstance( if is_token_prompt(prompt):
prompt, list) else tokenizer.encode(prompt) prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append(BeamSearchInstance(prompt_tokens)) instances.append(BeamSearchInstance(prompt_tokens))
for _ in range(max_tokens): for _ in range(max_tokens):
...@@ -952,6 +997,107 @@ class LLM: ...@@ -952,6 +997,107 @@ class LLM:
return [ClassificationRequestOutput.from_base(item) for item in items] return [ClassificationRequestOutput.from_base(item) for item in items]
def _embedding_score(
    self,
    tokenizer: AnyTokenizer,
    text_1: List[Union[str, TextPrompt, TokensPrompt]],
    text_2: List[Union[str, TextPrompt, TokensPrompt]],
    truncate_prompt_tokens: Optional[int] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
    """Score prompt pairs by the cosine similarity of their embeddings.

    Both prompt lists are embedded with a single ``self.encode`` call and
    then paired element-wise. When ``text_1`` yields exactly one
    embedding, it is repeated so that it is compared against every
    embedding of ``text_2``.

    Each result carries the token ids of both prompts concatenated, with
    the tokenizer's ``pad_token_id`` placed between them when the
    tokenizer defines one.

    NOTE(review): ``truncate_prompt_tokens`` is accepted but not used in
    this method.
    """
    num_first = len(text_1)
    embeddings = self.encode(
        text_1 + text_2,
        use_tqdm=use_tqdm,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request)
    first_embeds = embeddings[:num_first]
    second_embeds = embeddings[num_first:]

    # One-to-many scoring: broadcast the single left-hand embedding.
    if len(first_embeds) == 1:
        first_embeds = first_embeds * len(second_embeds)

    # Similarity is taken along dimension 0 of the embedding tensors.
    cosine = torch.nn.CosineSimilarity(0)

    pair_outputs = []
    for left, right in zip(first_embeds, second_embeds):
        similarity = cosine(left.outputs.data, right.outputs.data)

        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            joined_ids = left.prompt_token_ids + right.prompt_token_ids
        else:
            joined_ids = (left.prompt_token_ids + [pad_token_id] +
                          right.prompt_token_ids)

        pair_outputs.append(
            PoolingRequestOutput(
                request_id=f"{left.request_id}_{right.request_id}",
                outputs=similarity,
                prompt_token_ids=joined_ids,
                finished=True))

    validated = self.engine_class.validate_outputs(pair_outputs,
                                                   PoolingRequestOutput)
    return [ScoringRequestOutput.from_base(entry) for entry in validated]
def _cross_encoding_score(
    self,
    tokenizer: AnyTokenizer,
    text_1: List[Union[str, TextPrompt, TokensPrompt]],
    text_2: List[Union[str, TextPrompt, TokensPrompt]],
    truncate_prompt_tokens: Optional[int] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
    """Score prompt pairs by running each pair through the engine.

    Every ``(text_1[i], text_2[i])`` pair is tokenized jointly through the
    tokenizer's ``text`` / ``text_pair`` arguments, submitted as a pooling
    request, and the engine outputs are converted to scoring outputs.
    When ``text_1`` has exactly one element, it is paired with every
    element of ``text_2``.

    Args:
        tokenizer: Tokenizer used to encode each pair; it is called with
            ``text=`` and ``text_pair=`` keyword arguments.
        text_1: First element of each pair (or a single shared element).
        text_2: Second element of each pair.
        truncate_prompt_tokens: When given, tokenization is run with
            ``truncation=True`` and this value as ``max_length``.
        use_tqdm: Forwarded to ``_run_engine``.
        lora_request: Forwarded to ``_validate_and_add_requests``.
        prompt_adapter_request: Forwarded to
            ``_validate_and_add_requests``.

    Returns:
        One ``ScoringRequestOutput`` per validated engine output.

    Raises:
        ValueError: If ``tokenizer`` is a ``MistralTokenizer``.
    """
    if isinstance(tokenizer, MistralTokenizer):
        # The message names the real cause; the previous text talked about
        # `--task`, which is unrelated to this tokenizer check.
        raise ValueError(
            "MistralTokenizer not supported for cross-encoding")

    # One-to-many scoring: repeat the single query for every candidate.
    if len(text_1) == 1:
        text_1 = text_1 * len(text_2)

    tokenization_kwargs: Dict[str, Any] = {}
    if truncate_prompt_tokens is not None:
        tokenization_kwargs["truncation"] = True
        tokenization_kwargs["max_length"] = truncate_prompt_tokens

    parsed_prompts = []
    for query, candidate in zip(text_1, text_2):
        prompt_inputs = tokenizer(text=query,
                                  text_pair=candidate,
                                  **tokenization_kwargs)
        parsed_prompts.append(
            TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                # Not every tokenizer emits token type ids, hence `.get`.
                token_type_ids=prompt_inputs.get("token_type_ids")))

    self._validate_and_add_requests(
        prompts=parsed_prompts,
        params=PoolingParams(),
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )

    outputs = self._run_engine(use_tqdm=use_tqdm)
    items = self.engine_class.validate_outputs(outputs,
                                               PoolingRequestOutput)

    return [ScoringRequestOutput.from_base(item) for item in items]
def score( def score(
self, self,
text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]], text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
...@@ -1003,25 +1149,20 @@ class LLM: ...@@ -1003,25 +1149,20 @@ class LLM:
raise ValueError(" ".join(messages)) raise ValueError(" ".join(messages))
if not self.llm_engine.model_config.is_cross_encoder: if self.llm_engine.model_config.task not in ("embed", "score"):
raise ValueError("Your model does not support cross encoding")
if self.llm_engine.model_config.task != "score":
raise ValueError("Score API is only enabled for `--task score`")
tokenizer = self.llm_engine.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
raise ValueError( raise ValueError(
"MistralTokenizer not supported for cross-encoding") "Score API is only enabled for `--task embed or --task score`")
# the tokenizer for models such as # the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs # lists of tokens to the `text` and `text_pair` kwargs
tokenizer = self.llm_engine.get_tokenizer()
def ensure_str(prompt: SingletonPrompt): def ensure_str(prompt: SingletonPrompt):
if isinstance(prompt, dict): if isinstance(prompt, dict):
if "multi_modal_data" in prompt: if "multi_modal_data" in prompt:
raise ValueError("Multi-modal prompt is not " raise ValueError("Multi-modal prompt is not "
"supported for cross encoding") "supported for scoring")
elif "prompt_token_ids" in prompt: elif "prompt_token_ids" in prompt:
prompt = tokenizer.decode( prompt = tokenizer.decode(
cast(TokensPrompt, prompt)["prompt_token_ids"]) cast(TokensPrompt, prompt)["prompt_token_ids"])
...@@ -1047,40 +1188,15 @@ class LLM: ...@@ -1047,40 +1188,15 @@ class LLM:
if len(text_2) == 0: if len(text_2) == 0:
raise ValueError("At least one text_pair element must be given") raise ValueError("At least one text_pair element must be given")
if len(text_1) == 1: if self.llm_engine.model_config.is_cross_encoder:
text_1 = text_1 * len(text_2) return self._cross_encoding_score(tokenizer, text_1, text_2,
truncate_prompt_tokens, use_tqdm,
input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] lora_request,
pooling_params = PoolingParams() prompt_adapter_request)
else:
tokenization_kwargs: Dict[str, Any] = {} return self._embedding_score(tokenizer, text_1, text_2,
if truncate_prompt_tokens is not None: truncate_prompt_tokens, use_tqdm,
tokenization_kwargs["truncation"] = True lora_request, prompt_adapter_request)
tokenization_kwargs["max_length"] = truncate_prompt_tokens
parsed_prompts = []
for q, t in input_pairs:
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
parsed_prompts.append(engine_prompt)
self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
items = self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
def start_profile(self) -> None: def start_profile(self) -> None:
self.llm_engine.start_profile() self.llm_engine.start_profile()
...@@ -1088,6 +1204,36 @@ class LLM: ...@@ -1088,6 +1204,36 @@ class LLM:
def stop_profile(self) -> None: def stop_profile(self) -> None:
self.llm_engine.stop_profile() self.llm_engine.stop_profile()
def reset_prefix_cache(self) -> bool:
    """Ask the underlying engine to reset its prefix cache.

    Returns whatever the engine's ``reset_prefix_cache`` returns.
    """
    engine = self.llm_engine
    return engine.reset_prefix_cache()
def sleep(self, level: int = 1):
    """
    Send the engine to sleep. While asleep, the engine should not be
    given any requests to process; it is the caller's job to ensure that
    nothing is in flight between this call and the matching `wake_up`.

    The prefix cache is reset before the engine is put to sleep.

    :param level: The sleep level.

        Level 1 offloads the model weights and discards the kv cache (its
        content is forgotten). The weights are backed up in CPU memory,
        so enough CPU memory must be available to hold them. This level
        suits sleeping and later waking up to run the same model again.

        Level 2 discards both the model weights and the kv cache; the
        content of both is forgotten. This level suits waking up to run a
        different model or to update the model, where the previous
        weights are not needed, and it reduces CPU memory pressure.
    """
    engine = self.llm_engine
    self.reset_prefix_cache()
    engine.sleep(level=level)
def wake_up(self):
    """
    Bring the engine back from sleep mode.

    See the :meth:`sleep` method for more details.
    """
    engine = self.llm_engine
    engine.wake_up()
# LEGACY # LEGACY
def _convert_v1_inputs( def _convert_v1_inputs(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment