Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
......@@ -8,7 +8,7 @@ from typing import Tuple, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
......@@ -18,9 +18,12 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse, UsageInfo)
ErrorResponse,
RequestResponseMetadata,
UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
from vllm.logger import init_logger
......@@ -43,18 +46,18 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(async_engine_client=async_engine_client,
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger,
......@@ -78,15 +81,25 @@ class OpenAIServingCompletion(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
model_name = self.served_model_names[0]
model_name = self.base_model_paths[0].name
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
......@@ -95,8 +108,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
......@@ -124,8 +136,8 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = (
await self.async_engine_client.is_tracing_enabled())
is_tracing_enabled = (await
self.engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled:
trace_headers = extract_trace_headers(raw_request.headers)
......@@ -133,7 +145,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.headers):
log_tracing_disabled_warning()
generator = self.async_engine_client.generate(
generator = self.engine_client.generate(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params,
request_id_item,
......@@ -159,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response
if stream:
return self.completion_stream_generator(request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts),
tokenizer=tokenizer)
return self.completion_stream_generator(
request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts),
tokenizer=tokenizer,
request_metadata=request_metadata)
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
......@@ -192,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
......@@ -221,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing):
model_name: str,
num_prompts: int,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
......@@ -340,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing):
exclude_unset=False, exclude_none=True))
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
......@@ -354,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
......@@ -427,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens=num_prompt_tokens + num_generated_tokens,
)
request_metadata.final_usage_info = usage
return CompletionResponse(
id=request_id,
created=created_time,
......
......@@ -8,13 +8,13 @@ from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
......@@ -71,15 +71,15 @@ class OpenAIServingEmbedding(OpenAIServing):
def __init__(
self,
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
base_model_paths: List[BaseModelPath],
*,
request_logger: Optional[RequestLogger],
):
super().__init__(async_engine_client=async_engine_client,
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
......@@ -118,8 +118,7 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
pooling_params = request.to_pooling_params()
......@@ -144,7 +143,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"Prompt adapter is not supported "
"for embedding models")
generator = self.async_engine_client.encode(
generator = self.engine_client.encode(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
pooling_params,
request_id_item,
......
......@@ -8,7 +8,7 @@ from pydantic import Field
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
......@@ -39,6 +39,12 @@ from vllm.utils import AtomicCounter
logger = init_logger(__name__)
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class PromptAdapterPath:
name: str
......@@ -49,6 +55,7 @@ class PromptAdapterPath:
class LoRAModulePath:
name: str
path: str
base_model_name: Optional[str] = None
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
......@@ -64,9 +71,9 @@ class OpenAIServing:
def __init__(
self,
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
......@@ -75,21 +82,24 @@ class OpenAIServing:
):
super().__init__()
self.async_engine_client = async_engine_client
self.engine_client = engine_client
self.model_config = model_config
self.max_model_len = model_config.max_model_len
self.served_model_names = served_model_names
self.base_model_paths = base_model_paths
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = []
if lora_modules is not None:
self.lora_requests = [
LoRARequest(
lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
) for i, lora in enumerate(lora_modules, start=1)
LoRARequest(lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
base_model_name=lora.base_model_name
if lora.base_model_name
and self._is_model_supported(lora.base_model_name)
else self.base_model_paths[0].name)
for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
......@@ -112,21 +122,23 @@ class OpenAIServing:
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=served_model_name,
ModelCard(id=base_model.name,
max_model_len=self.max_model_len,
root=self.served_model_names[0],
root=base_model.model_path,
permission=[ModelPermission()])
for served_model_name in self.served_model_names
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=self.served_model_names[0],
root=lora.local_path,
parent=lora.base_model_name if lora.base_model_name else
self.base_model_paths[0].name,
permission=[ModelPermission()])
for lora in self.lora_requests
]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.served_model_names[0],
root=self.base_model_paths[0].name,
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
......@@ -159,7 +171,7 @@ class OpenAIServing:
async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.async_engine_client.get_decoding_config()
decoding_config = await self.engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor(
......@@ -169,7 +181,7 @@ class OpenAIServing:
self,
request: AnyRequest,
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
if self._is_model_supported(request.model):
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
return None
......@@ -187,7 +199,7 @@ class OpenAIServing:
self, request: AnyRequest
) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
None, PromptAdapterRequest]]:
if request.model in self.served_model_names:
if self._is_model_supported(request.model):
return None, None
for lora in self.lora_requests:
if request.model == lora.lora_name:
......@@ -480,3 +492,6 @@ class OpenAIServing:
if lora_request.lora_name != lora_name
]
return f"Success: LoRA adapter '{lora_name}' removed successfully."
def _is_model_supported(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
......@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import MistralTokenizer
......@@ -29,17 +30,17 @@ class OpenAIServingTokenization(OpenAIServing):
def __init__(
self,
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(async_engine_client=async_engine_client,
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=None,
request_logger=request_logger)
......@@ -66,7 +67,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
prompt: Union[str, List[int]]
if isinstance(request, TokenizeChatRequest):
......@@ -132,7 +133,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
self._log_inputs(request_id,
request.tokens,
......
......@@ -59,10 +59,12 @@ if TYPE_CHECKING:
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_RPC_TIMEOUT: int = 10000 # ms
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False
def get_default_cache_root():
......@@ -206,6 +208,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")),
# If set, allowing the use of deprecated beam search implementation
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH":
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1",
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
......@@ -214,6 +220,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")),
# Internal flag to control whether we use custom op,
# or use the native pytorch implementation
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool(
......@@ -399,8 +410,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
"VLLM_RPC_GET_DATA_TIMEOUT_MS":
lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
"VLLM_RPC_TIMEOUT":
lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")),
# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded
......
......@@ -106,6 +106,7 @@ class CPUExecutor(ExecutorBase):
)) for rank in range(1, world_size)
]
self.worker_monitor = None
if world_size != 1 or is_async:
if is_async:
async_worker_list = self.workers + [self.driver_worker]
......
import asyncio
import os
import signal
import threading
import weakref
from functools import partial
from typing import Any, List, Optional
......@@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well
# Use weakref to avoid holding a reference to self
ref = weakref.ref(self)
def shutdown(signum, frame):
if executor := ref():
executor.shutdown()
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)
self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self._run_workers("init_device")
......
......@@ -76,8 +76,7 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
def __init__(self) -> None:
super().__init__(daemon=False)
# super().__init__(daemon=True)
super().__init__(daemon=True)
self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
......@@ -101,8 +100,7 @@ class WorkerMonitor(threading.Thread):
def __init__(self, workers: List['ProcessWorkerWrapper'],
result_handler: ResultHandler):
super().__init__(daemon=False)
# super().__init__(daemon=True)
super().__init__(daemon=True)
self.workers = workers
self.result_handler = result_handler
self._close = False
......@@ -114,30 +112,16 @@ class WorkerMonitor(threading.Thread):
self._close = True
# Kill / cleanup all workers
# for worker in self.workers:
# process = worker.process
# if process.sentinel in dead_sentinels:
# process.join(JOIN_TIMEOUT_S)
# if process.exitcode is not None and process.exitcode != 0:
# logger.error("Worker %s pid %s died, exit code: %s",
# process.name, process.pid, process.exitcode)
if not sys.is_finalizing():
# Kill / cleanup all workers
died_count = 0
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
died_count += 1
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode)
if died_count < len(self.workers):
logger.info(
"Killing remaining local vLLM worker processes")
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode)
# Cleanup any remaining workers
# logger.info("Killing local vLLM worker processes")
if logger:
logger.info("Killing local vLLM worker processes")
for worker in self.workers:
worker.kill_worker()
# Must be done after worker task queues are all closed
......@@ -184,6 +168,8 @@ class ProcessWorkerWrapper:
self.tasks[task_id] = future
try:
self._task_queue.put((task_id, method, args, kwargs))
except SystemExit:
raise
except BaseException as e:
del self.tasks[task_id]
raise ChildProcessError("worker died") from e
......@@ -238,6 +224,10 @@ def _run_worker_process(
try:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
except SystemExit:
raise
except KeyboardInterrupt:
break
except BaseException as e:
tb = traceback.format_exc()
logger.error(
......@@ -278,4 +268,4 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file.start_new_line = False # type: ignore[attr-defined]
file.start_new_line = True # type: ignore[attr-defined]
file.write = write_with_prefix # type: ignore[method-assign]
file.write = write_with_prefix # type: ignore[method-assign]
\ No newline at end of file
......@@ -437,8 +437,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
required_version = version.parse("2.35")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is "
# TODO: update the constraint once we adapt to the backward
# incompatible API change from ray 2.36
if current_version != required_version:
raise ValueError(f"Ray version {required_version} is "
f"required, but found {current_version}")
import importlib.util
......
......@@ -26,6 +26,8 @@ logger = init_logger(__name__)
class RayTPUExecutor(TPUExecutor):
uses_ray: bool = True
def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
......@@ -68,8 +70,12 @@ class RayTPUExecutor(TPUExecutor):
)
assert self.speculative_config is None
worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker"
if self.scheduler_config.is_multi_step:
worker_module_name = "vllm.worker.multi_step_tpu_worker"
worker_class_name = "MultiStepTPUWorker"
else:
worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker"
# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
......
......@@ -18,9 +18,14 @@ PG_WAIT_TIMEOUT = 1800
try:
import ray
from ray._private.state import available_resources_per_node
from ray.util import placement_group_table
from ray.util.placement_group import PlacementGroup
try:
from ray._private.state import available_resources_per_node
except ImportError:
# Ray 2.9.x doesn't expose `available_resources_per_node`
from ray._private.state import state as _state
available_resources_per_node = _state._available_resources_per_node
class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
......
......@@ -62,11 +62,17 @@ class TPUExecutor(ExecutorBase):
rank: int = 0,
distributed_init_method: Optional[str] = None,
):
from vllm.worker.tpu_worker import TPUWorker
worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return worker
if self.scheduler_config.is_multi_step:
from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker
worker = MultiStepTPUWorker(**self._get_worker_kwargs(
local_rank, rank, distributed_init_method))
return worker
else:
from vllm.worker.tpu_worker import TPUWorker
worker = TPUWorker(**self._get_worker_kwargs(
local_rank, rank, distributed_init_method))
return worker
def initialize_cache(
self,
......
......@@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
available.
"""
encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the encoder model,
if the model supports it.
"""
_T1 = TypeVar("_T1",
bound=SingletonPromptInputs,
......
......@@ -128,6 +128,7 @@ class InputPreprocessor:
def _prepare_decoder_input_ids_for_generation(
self,
decoder_input_ids: Optional[List[int]],
force_bos: bool = True,
) -> List[int]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
......@@ -157,8 +158,8 @@ class InputPreprocessor:
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
if force_bos and (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
return decoder_input_ids
......@@ -295,18 +296,25 @@ class InputPreprocessor:
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
if encoder_mm_data is not None or decoder_mm_data is not None:
raise ValueError("Multi-modal encoder-decoder models are "
"not supported yet")
if decoder_mm_data is not None:
raise ValueError(
"Multi-modality decoder inputs of encoder-decoder models are "
"not supported yet")
decoder_prompt_ids = (
self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
# For Multi-Modal models (e.g., mllama), the text input can be
# <|image|><|begin_of_text|>hello world. And we should not add
# another <|begin_of_text|> to the beginning.
decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation(
decoder_prompt_ids,
force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
return EncoderDecoderLLMInputs(
prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt,
multi_modal_data=decoder_mm_data,
encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt,
encoder_multi_modal_data=encoder_mm_data,
)
def _process_encoder_decoder_prompt(
......
import functools
from array import array
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
......@@ -10,6 +9,7 @@ from transformers import PretrainedConfig
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.utils import get_allowed_kwarg_only_overrides
from .data import LLMInputs
......@@ -22,10 +22,6 @@ logger = init_logger(__name__)
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
# We cannot import it here because of circular dependencies.
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@dataclass(frozen=True)
class InputContext:
......@@ -73,12 +69,17 @@ class DummyDataFactory(Protocol):
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
**mm_processor_kwargs: Any,
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
Create dummy data to be inputted into the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
The :code:`mm_processor_kwargs` are overrides provided at
initialization time to values in the config whose values
may affect the number of tokens per instance.
"""
...
......@@ -111,6 +112,8 @@ class InputRegistry:
def __init__(self) -> None:
self._dummy_factories_by_model_type: Dict[Type[nn.Module],
DummyDataFactory] = {}
self._dummy_encoder_factories_by_model_type: Dict[
Type[nn.Module], DummyDataFactory] = {}
self._input_processors_by_model_type: Dict[Type[nn.Module],
InputProcessor] = {}
......@@ -130,8 +133,7 @@ class InputRegistry:
# Avoid circular import
from vllm.sequence import SequenceData
dummy_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
dummy_seq_data = SequenceData.from_token_counts((0, seq_len))
dummy_multi_modal_data = None
return dummy_seq_data, dummy_multi_modal_data
......@@ -158,11 +160,48 @@ class InputRegistry:
return wrapper
def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
def register_dummy_encoder_data(self, factory: DummyDataFactory):
"""
Register a dummy encoder data factory to a model class
This is similar to :meth:`~register_dummy_data`, but for encoder input.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_encoder_factories_by_model_type:
logger.warning(
"Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_encoder_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
if model_cls in self._dummy_encoder_factories_by_model_type:
dummy_factory = self._dummy_encoder_factories_by_model_type[
model_cls]
else:
logger.warning(
"No dummy encoder data factory registered to %s. "
"Using the dummy data factory for the model instead.",
model_cls)
dummy_factory = self._get_dummy_data_factory(model_cls)
return dummy_factory
def dummy_data_for_profiling(
self,
model_config: "ModelConfig",
seq_len: int,
mm_registry: "MultiModalRegistry",
is_encoder_data: bool = False,
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
Create dummy data for profiling the memory usage of a model.
......@@ -180,22 +219,29 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
dummy_factory = self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
if is_encoder_data:
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
else:
dummy_factory = self._get_dummy_data_factory(model_cls)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory, overrides=model_config.mm_processor_kwargs)
seq_data, mm_data = dummy_factory(
InputContext(model_config),
seq_len,
_MultiModalCounts(mm_counts),
)
seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len,
_MultiModalCounts(mm_counts),
**mm_processor_kwargs)
# Having more tokens is over-conservative but otherwise fine
num_tokens = seq_data.prompt_token_ids
assert len(num_tokens) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if len(num_tokens) < seq_len:
if is_encoder_data:
logger.warning(
"Expected at least %d dummy encoder tokens for profiling, "
"but found %d tokens instead.", seq_len, len(num_tokens))
else:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if mm_data is not None:
for k, v in mm_data.items():
num_items = len(v) if isinstance(v, list) else 1
......@@ -235,6 +281,10 @@ class InputRegistry:
return wrapper
def _get_model_input_processor(self, model_cls: Type[nn.Module]):
return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)
def process_input(self, model_config: "ModelConfig",
inputs: LLMInputs) -> LLMInputs:
"""
......@@ -249,15 +299,17 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
processor = self._get_model_input_processor(model_cls)
processor = self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
processor, overrides=model_config.mm_processor_kwargs)
return processor(InputContext(model_config), inputs)
return processor(InputContext(model_config), inputs,
**mm_processor_kwargs)
def create_input_processor(self, model_config: "ModelConfig"):
"""
Create an input processor (see :meth:`process_input`) for a
Create an input processor (see :meth:`_process_input`) for a
specific model.
"""
return functools.partial(self.process_input, model_config)
......@@ -100,7 +100,7 @@ def _bgmv_expand(
corresponding to each batch, An index of -1 means no lora should be
applied.
batches (int): batch size
add_inputs (bool, optional): Defaults to False. adds the final lora
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
......
......@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
slice_offst (int): output_tensor's offst
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
......
......@@ -106,6 +106,7 @@ def _sgmv_expand(
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
) -> None:
"""
......@@ -115,17 +116,19 @@ def _sgmv_expand(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
add_inputs (bool, optional): Defaults to False. adds the final lora
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
......@@ -134,6 +137,7 @@ def _sgmv_expand(
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
......
......@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
......@@ -124,20 +125,22 @@ def _sgmv_expand_slice(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
max_seq_length (int): The max sequence lengths of the sequences
in the batch
slice_offst (int): output_tensor's offst
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output..
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
......@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
......
......@@ -110,6 +110,7 @@ def _sgmv_shrink(
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
"""
......@@ -120,17 +121,19 @@ def _sgmv_shrink(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
scaling (float): Scaling factor.
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
......@@ -138,6 +141,7 @@ def _sgmv_shrink(
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_a_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment