Commit af7f4372 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional
from typing import Sequence as GenericSequence
from typing import Union
from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_chat_template,
load_chat_template,
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
......@@ -22,14 +23,15 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
from vllm.inputs import PromptInputs
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.utils import random_uuid
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid
logger = init_logger(__name__)
......@@ -65,9 +67,9 @@ class OpenAIServingChat(OpenAIServing):
async def create_chat_completion(
self,
request: ChatCompletionRequest,
raw_request: Optional[Request] = None
) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]:
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
......@@ -98,16 +100,15 @@ class OpenAIServingChat(OpenAIServing):
tool.model_dump() for tool in request.tools
]
prompt = tokenizer.apply_chat_template(
prompt = apply_chat_template(
tokenizer,
conversation=conversation,
tokenize=False,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
assert isinstance(prompt, str)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
......@@ -149,9 +150,8 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
engine_inputs: PromptInputs = {
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
}
engine_inputs = TokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
......@@ -176,18 +176,20 @@ class OpenAIServingChat(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
if raw_request:
result_generator = iterate_with_cancellation(
result_generator, raw_request.is_disconnected)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
conversation, tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
......@@ -201,11 +203,11 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True
# Send response for each token for each request.n (index)
......@@ -422,23 +424,22 @@ class OpenAIServingChat(OpenAIServing):
async def chat_completion_full_generator(
self,
request: ChatCompletionRequest,
raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0]
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
async for res in result_generator:
if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.async_engine_client.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
try:
async for res in result_generator:
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
assert final_res is not None
choices: List[ChatCompletionResponseChoice] = []
......@@ -504,13 +505,14 @@ class OpenAIServingChat(OpenAIServing):
model=model_name,
choices=choices,
usage=usage,
prompt_logprobs=final_res.prompt_logprobs,
)
return response
def _get_top_logprobs(
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1],
......@@ -528,12 +530,11 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
logprobs_content = []
logprobs_content: List[ChatCompletionLogProbsContent] = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
......@@ -541,23 +542,32 @@ class OpenAIServingChat(OpenAIServing):
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
bytes=list(token.encode("utf-8", errors="replace"))))
bytes=list(token.encode("utf-8", errors="replace")),
))
else:
step_token = step_top_logprobs[token_id]
step_decoded = step_token.decoded_token
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self._get_decoded_token(
step_top_logprobs[token_id], token_id, tokenizer,
self.return_tokens_as_token_ids),
logprob=max(step_top_logprobs[token_id].logprob,
-9999.0),
bytes=list(
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
step_token,
token_id,
tokenizer,
self.return_tokens_as_token_ids,
),
logprob=max(step_token.logprob, -9999.0),
bytes=None if step_decoded is None else list(
step_decoded.encode("utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs,
tokenizer)))
step_top_logprobs,
num_output_top_logprobs,
tokenizer,
),
))
return ChatCompletionLogProbs(content=logprobs_content)
import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional)
from typing import Sequence as GenericSequence
from typing import Tuple, cast
from typing import Tuple, Union, cast
from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
......@@ -18,7 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
UsageInfo)
ErrorResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
......@@ -28,6 +28,7 @@ from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
......@@ -59,8 +60,11 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
async def create_completion(
self,
request: CompletionRequest,
raw_request: Request,
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
......@@ -84,7 +88,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time = int(time.time())
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
(
lora_request,
......@@ -143,8 +147,8 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, RequestOutput]] = merge_async_iterators(*generators)
result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
......@@ -156,7 +160,6 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response
if stream:
return self.completion_stream_generator(request,
raw_request,
result_generator,
request_id,
created_time,
......@@ -168,10 +171,6 @@ class OpenAIServingCompletion(OpenAIServing):
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
......@@ -194,6 +193,8 @@ class OpenAIServingCompletion(OpenAIServing):
model_name,
tokenizer,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
......@@ -214,13 +215,12 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
raw_request: Request,
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts
......@@ -229,12 +229,13 @@ class OpenAIServingCompletion(OpenAIServing):
try:
async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt
# Abort the request if the client disconnects.
if await raw_request.is_disconnected():
await self.async_engine_client.abort(
f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]]
for output in res.outputs:
i = output.index + prompt_idx * num_choices
......@@ -243,19 +244,25 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_text is not None
# only return the prompt
delta_text = res.prompt
delta_token_ids = res.prompt_token_ids
out_logprobs = res.prompt_logprobs
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = (res.prompt_token_ids +
output.token_ids)
out_logprobs = res.prompt_logprobs + (output.logprobs
or [])
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids, *output.token_ids
]
out_logprobs = [
*prompt_logprobs,
*(output.logprobs or []),
]
has_echoed[i] = True
else:
# return just the delta
......@@ -300,7 +307,7 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
prompt_tokens = len(res.prompt_token_ids)
prompt_tokens = len(prompt_token_ids)
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
......@@ -341,7 +348,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
......@@ -352,16 +359,31 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
Logprob]]]]
for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_text is not None
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + list(output.token_ids)
out_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs is not None else None)
assert prompt_text is not None
token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
......@@ -385,6 +407,7 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
prompt_logprobs=final_res.prompt_logprobs,
)
choices.append(choice_data)
......@@ -411,7 +434,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
......@@ -428,17 +451,21 @@ class OpenAIServingCompletion(OpenAIServing):
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
step_token = step_top_logprobs[token_id]
token = self._get_decoded_token(
step_top_logprobs[token_id],
step_token,
token_id,
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
return_as_token_id=self.return_tokens_as_token_ids,
)
token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)
......
import asyncio
import base64
import time
from typing import AsyncIterator, List, Optional, Tuple, cast
from typing import AsyncGenerator, List, Literal, Optional, Union, cast
import numpy as np
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData, UsageInfo)
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
......@@ -21,18 +24,28 @@ logger = init_logger(__name__)
TypeTokenIDs = List[int]
def _get_embedding(
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
return output.embedding
elif encoding_format == "base64":
embedding_bytes = np.array(output.embedding).tobytes()
return base64.b64encode(embedding_bytes).decode("utf-8")
assert_never(encoding_format)
def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
created_time: int, model_name: str,
encoding_format: str) -> EmbeddingResponse:
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding
if encoding_format == "base64":
embedding_bytes = np.array(embedding).tobytes()
embedding = base64.b64encode(embedding_bytes).decode("utf-8")
embedding = _get_embedding(final_res.outputs, encoding_format)
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data)
......@@ -68,21 +81,25 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self._check_embedding_mode(model_config.embedding_mode)
self._enabled = self._check_embedding_mode(model_config.embedding_mode)
async def create_embedding(self, request: EmbeddingRequest,
raw_request: Request):
async def create_embedding(
self,
request: EmbeddingRequest,
raw_request: Optional[Request] = None,
) -> Union[EmbeddingResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
if not self._enabled:
return self.create_error_response("Embedding API disabled")
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
encoding_format = (request.encoding_format
if request.encoding_format else "float")
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
......@@ -92,7 +109,7 @@ class OpenAIServingEmbedding(OpenAIServing):
created_time = int(time.monotonic())
# Schedule the request and get the result generator.
generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try:
(
lora_request,
......@@ -137,18 +154,16 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
result_generator = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for final_res in final_res_batch:
......@@ -160,15 +175,18 @@ class OpenAIServingEmbedding(OpenAIServing):
response = request_output_to_embedding_response(
final_res_batch_checked, request_id, created_time, model_name,
encoding_format)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def _check_embedding_mode(self, embedding_mode: bool):
def _check_embedding_mode(self, embedding_mode: bool) -> bool:
if not embedding_mode:
logger.warning(
"embedding_mode is False. Embedding API will not work.")
else:
logger.info("Activating the server engine with embedding enabled.")
return embedding_mode
......@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeCompletionRequest,
TokenizeRequest)
# yapf: enable
from vllm.inputs import parse_and_batch_prompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
......@@ -31,7 +31,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
......
......@@ -2,7 +2,9 @@ from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
from vllm.entrypoints.chat_utils import (apply_chat_template,
load_chat_template,
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
......@@ -70,12 +72,12 @@ class OpenAIServingTokenization(OpenAIServing):
logger.warning(
"Multi-modal inputs are ignored during tokenization")
prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt,
prompt = apply_chat_template(
tokenizer,
conversation=conversation,
tokenize=False,
chat_template=self.chat_template)
assert isinstance(prompt, str)
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
)
else:
prompt = request.prompt
......
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None
VLLM_RPC_PORT: int = 5570
VLLM_RPC_BASE_PATH: str = tempfile.gettempdir()
VLLM_USE_MODELSCOPE: bool = False
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_INSTANCE_ID: Optional[str] = None
......@@ -32,6 +33,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
......@@ -46,14 +48,21 @@ if TYPE_CHECKING:
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_AUDIO_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False
VLLM_NO_DEPRECATION_WARNING: bool = False
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
def get_default_cache_root():
......@@ -132,7 +141,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
os.path.join(get_default_cache_root(), "vllm"),
)),
# used in distributed environment to determine the master address
# used in distributed environment to determine the ip address
# of the current node, when the node has multiple network interfaces.
# If you are using multi-node inference, you should set this differently
# on each node.
'VLLM_HOST_IP':
lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""),
......@@ -145,10 +157,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: int(os.getenv('VLLM_PORT', '0'))
if 'VLLM_PORT' in os.environ else None,
# used when the frontend api server is running in multi-processing mode,
# to communicate with the backend engine process over ZMQ.
'VLLM_RPC_PORT':
lambda: int(os.getenv('VLLM_PORT', '5570')),
# path used for ipc when the frontend api server is running in
# multi-processing mode to communicate with the backend engine process.
'VLLM_RPC_BASE_PATH':
lambda: os.getenv('VLLM_RPC_BASE_PATH', tempfile.gettempdir()),
# If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers
......@@ -268,6 +280,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
......@@ -336,12 +352,17 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
# Timeout for fetching audio when serving multimodal models
# Default is 5 seconds
"VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "5")),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH":
lambda: os.path.expanduser(
os.getenv(
"VLLM_ASSETS_CACHE",
"VLLM_XLA_CACHE_PATH",
os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
)),
"VLLM_FUSED_MOE_CHUNK_SIZE":
......@@ -351,6 +372,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_NO_DEPRECATION_WARNING":
lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),
# If set, the OpenAI API server will stay alive even after the underlying
# AsyncLLMEngine errors and stops serving requests
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH":
lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)),
# If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
# the user to specify a max sequence length greater than
# the max length derived from the model's config.json.
......@@ -359,6 +385,39 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda:
(os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
("1", "true")),
# If set, forces FP8 Marlin to be used for FP8 quantization regardless
# of the hardware support for FP8 compute.
"VLLM_TEST_FORCE_FP8_MARLIN":
lambda:
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
("1", "true")),
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
"VLLM_RPC_GET_DATA_TIMEOUT_MS":
lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
# If set, allow running the engine as a separate ray actor,
# which is a deprecated feature soon to be removed.
# See https://github.com/vllm-project/vllm/issues/7045
"VLLM_ALLOW_ENGINE_USE_RAY":
lambda:
(os.environ.get("VLLM_ALLOW_ENGINE_USE_RAY", "0").strip().lower() in
("1", "true")),
# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded
# if this is set to an empty string, no plugins will be loaded
"VLLM_PLUGINS":
lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
"VLLM_PLUGINS"].split(","),
# Enables torch profiler if set. Path to the directory where torch profiler
# traces are saved. Note that it must be an absolute path.
"VLLM_TORCH_PROFILER_DIR":
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
}
# end-env-vars-definition
......
......@@ -13,7 +13,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_open_port,
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async)
from vllm.worker.worker_base import WorkerWrapperBase
......@@ -141,7 +141,6 @@ class CPUExecutor(ExecutorBase):
rank=rank,
distributed_init_method=self.distributed_init_method,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=rank == 0,
......@@ -332,7 +331,6 @@ def _verify_and_get_scheduler_config(
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
_GB = 1 << 30
if config.enable_prefix_caching:
logger.warning("Prefix caching is not supported on CPU, disable it.")
config.enable_prefix_caching = False
......@@ -341,11 +339,11 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
if kv_cache_space >= 0:
if kv_cache_space == 0:
config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore
config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
"for CPU backend is not set, using 4 by default.")
else:
config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore
config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
else:
raise RuntimeError(
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
......
......@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.lora.request import LoRARequest
......@@ -29,9 +29,9 @@ class ExecutorBase(ABC):
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
......@@ -40,10 +40,9 @@ class ExecutorBase(ABC):
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.multimodal_config = multimodal_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
self._init_executor()
@abstractmethod
......
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
......@@ -7,15 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
logger = init_logger(__name__)
def create_worker(worker_module_name, worker_class_name, **kwargs):
def create_worker(worker_module_name: str, worker_class_name: str,
worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
**kwargs):
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
)
wrapper.init_worker(**kwargs)
return wrapper.worker
......@@ -55,13 +58,27 @@ class GPUExecutor(ExecutorBase):
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
observability_config=self.observability_config,
)
def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
if self.scheduler_config.is_multi_step:
worker_module_name = "vllm.worker.multi_step_worker"
worker_class_name = "MultiStepWorker"
elif self.speculative_config:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
return (worker_module_name, worker_class_name, worker_class_fn)
def _get_create_worker_kwargs(
self,
local_rank: int = 0,
......@@ -69,13 +86,15 @@ class GPUExecutor(ExecutorBase):
distributed_init_method: Optional[str] = None) -> Dict:
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method)
if self.speculative_config is None:
worker_kwargs.update(worker_module_name="vllm.worker.worker",
worker_class_name="Worker")
else:
worker_kwargs.update(
worker_module_name="vllm.spec_decode.spec_decode_worker",
worker_class_name="create_spec_worker")
(worker_module_name, worker_class_name,
worker_class_fn) = self._get_worker_module_and_class()
worker_kwargs.update(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
)
return worker_kwargs
def _create_worker(self,
......
from array import array
from typing import Any, Type
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any:
"""Custom msgspec enc hook that supports array types.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if isinstance(obj, array):
assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
f"Given array has a type code of {obj.typecode}.")
return obj.tobytes()
def decode_hook(type: Type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if type is array:
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
deserialized.frombytes(obj)
return deserialized
......@@ -100,9 +100,8 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(
self.driver_worker.execute_model
)(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output
async def check_health_async(self) -> None:
......
......@@ -10,8 +10,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
get_open_port, make_async)
logger = init_logger(__name__)
......@@ -49,7 +49,6 @@ class OpenVINOExecutor(ExecutorBase):
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
......@@ -165,14 +164,13 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
if kv_cache_space >= 0:
_GB = 1 << 30
if kv_cache_space == 0:
config.openvino_kvcache_space_bytes = 4 * _GB # type: ignore
config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
"for OpenVINO backend is not set, using 4 by default.")
else:
config.openvino_kvcache_space_bytes = kv_cache_space * _GB # type: ignore
config.openvino_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
else:
raise RuntimeError(
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
......
......@@ -4,9 +4,12 @@ from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import msgspec
import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
......@@ -60,6 +63,18 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
def shutdown(self) -> None:
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.forward_dag = None
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
......@@ -76,19 +91,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
return ray_remote_kwargs
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
(worker_module_name, worker_class_name,
worker_class_fn) = self._get_worker_module_and_class()
return dict(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
trust_remote_code=self.model_config.trust_remote_code,
)
# child class could overwrite this to return actual env vars.
def _get_env_vars_to_be_updated(self):
return self._env_vars_for_all_workers
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if (self.parallel_config.tensor_parallel_size == 1
......@@ -115,9 +131,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
ray_remote_kwargs)
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
driver_ip = get_ip()
logger.info("driver_ip: %s", driver_ip)
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
......@@ -202,6 +218,19 @@ class RayGPUExecutor(DistributedGPUExecutor):
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP` or "
"`HOST_IP` environment variable, make sure it is unique for"
" each node.")
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
......@@ -213,8 +242,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for (node_id, _) in worker_node_and_gpu_ids]
self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
all_args=self._get_env_vars_to_be_updated())
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
......@@ -297,8 +330,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
outputs = ray.get(self.forward_dag.execute(execute_model_req))
return outputs[0]
serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data))
output = self.output_decoder.decode(outputs[0])
return output
def _run_workers(
self,
......@@ -446,11 +481,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def __del__(self):
if self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.shutdown()
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
......@@ -472,9 +503,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
dag_future = await self.forward_dag.execute_async(execute_model_req)
serialized_data = self.input_encoder.encode(execute_model_req)
dag_future = await self.forward_dag.execute_async(serialized_data)
outputs = await dag_future
return outputs[0]
return self.output_decoder.decode(outputs[0])
async def _driver_execute_model_async(
self,
......@@ -523,8 +555,4 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
return await asyncio.gather(*coros)
def __del__(self):
if self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.shutdown()
from typing import List, Optional, Tuple, Union
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
import msgspec
from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
from vllm.utils import get_ip, is_hip, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
PG_WAIT_TIMEOUT = 1800
try:
import ray
from ray._private.state import available_resources_per_node
from ray.util import placement_group_table
from ray.util.placement_group import PlacementGroup
class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
......@@ -23,6 +33,10 @@ try:
# that thread.
self.compiled_dag_cuda_device_set = False
self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=decode_hook)
self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
def get_node_ip(self) -> str:
return get_ip()
......@@ -32,16 +46,26 @@ try:
return node_id, gpu_ids
def execute_model_spmd(
self, req_or_tuple: Union[ExecuteModelRequest,
Tuple[ExecuteModelRequest,
IntermediateTensors]]):
self, req_or_tuple: Union[bytes,
Tuple[bytes,
Optional[IntermediateTensors]]]
) -> bytes:
"""Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled.
Args:
req_or_tuple: The request to execute the model, or a tuple
containing the request and intermediate tensors.
req_or_tuple: A request or a tuple containing the
request and intermediate tensors. Intermediate tensors are
None unless if it is provided because it is > 0 pipeline
stage. The request is serialized by msgspec.
"""
if isinstance(req_or_tuple, bytes):
serialized_req, intermediate_tensors = req_or_tuple, None
else:
serialized_req, intermediate_tensors = req_or_tuple
execute_model_req = self.input_decoder.decode(serialized_req)
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# device.
......@@ -50,16 +74,14 @@ try:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
if isinstance(req_or_tuple, tuple):
execute_model_req, intermediate_tensors = req_or_tuple
else:
execute_model_req = req_or_tuple
intermediate_tensors = None
output = self.worker._execute_model_spmd(execute_model_req,
intermediate_tensors)
# Pipeline model request and output to the next pipeline stage.
if isinstance(output, IntermediateTensors):
return execute_model_req, output
output = serialized_req, output
else:
output = self.output_encoder.encode(output)
return output
ray_import_err = None
......@@ -82,6 +104,106 @@ def assert_ray_available():
"`pip install ray`.") from ray_import_err
def _verify_bundles(placement_group: "PlacementGroup",
parallel_config: ParallelConfig, device_str: str):
"""Verify a given placement group has bundles located in the right place.
There are 2 rules.
- Warn if all tensor parallel workers cannot fit in a single node.
- Fail if driver node is not included in a placement group.
"""
assert ray.is_initialized(), (
"Ray is not initialized although distributed-executor-backend is ray.")
pg_data = placement_group_table(placement_group)
# bundle_idx -> node_id
bundle_to_node_ids = pg_data["bundles_to_node_id"]
# bundle_idx -> bundle (e.g., {"GPU": 1})
bundles = pg_data["bundles"]
# node_id -> List of bundle (e.g., {"GPU": 1})
node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
for bundle_idx, node_id in bundle_to_node_ids.items():
node_id_to_bundle[node_id].append(bundles[bundle_idx])
driver_node_id = ray.get_runtime_context().get_node_id()
if driver_node_id not in node_id_to_bundle:
raise RuntimeError(
f"driver node id {driver_node_id} is not included in a placement "
f"group {placement_group.id}. Node id -> bundles "
f"{node_id_to_bundle}. "
"You don't have enough GPUs available in a current node. Check "
"`ray status` to see if you have available GPUs in a node "
f"{driver_node_id} before starting an vLLM engine.")
for node_id, bundles in node_id_to_bundle.items():
if len(bundles) < parallel_config.tensor_parallel_size:
logger.warning(
"tensor_parallel_size=%d "
"is bigger than a reserved number of %ss (%d "
"%ss) in a node %s. Tensor parallel workers can be "
"spread out to 2+ nodes which can degrade the performance "
"unless you have fast interconnect across nodes, like "
"Infiniband. To resolve this issue, make sure you have more "
"than %d GPUs available at each node.",
parallel_config.tensor_parallel_size, device_str, len(bundles),
device_str, node_id, parallel_config.tensor_parallel_size)
def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
"""Wait until a placement group is ready.
It prints the informative log messages if the placement group is
not created within time.
"""
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
placement_group_specs = current_placement_group.bundle_specs
s = time.time()
pg_ready_ref = current_placement_group.ready()
wait_interval = 10
while time.time() - s < PG_WAIT_TIMEOUT:
ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval)
if len(ready) > 0:
break
# Exponential backoff for warning print.
wait_interval *= 2
logger.info(
"Waiting for creating a placement group of specs for "
"%d seconds. specs=%s. Check "
"`ray status` to see if you have enough resources.",
int(time.time() - s), placement_group_specs)
try:
ray.get(pg_ready_ref, timeout=0)
except ray.exceptions.GetTimeoutError:
raise ValueError(
"Cannot provide a placement group of "
f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
"`ray status` to make sure the cluster has enough resources."
) from None
def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
ray.util.remove_placement_group(current_placement_group)
s = time.time()
wait_interval = 10
while time.time() - s < PG_WAIT_TIMEOUT:
pg = ray.util.get_current_placement_group()
if pg is None:
break
# Exponential backoff for warning print.
wait_interval *= 2
logger.info(
"Waiting for removing a placement group of specs for "
"%d seconds.", int(time.time() - s))
time.sleep(wait_interval)
def initialize_ray_cluster(
parallel_config: ParallelConfig,
ray_address: Optional[str] = None,
......@@ -111,7 +233,7 @@ def initialize_ray_cluster(
# Placement group is already set.
return
device_str = "GPU" if not is_tpu() else "TPU"
device_str = "GPU" if not current_platform.is_tpu() else "TPU"
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
......@@ -140,15 +262,32 @@ def initialize_ray_cluster(
f"The number of required {device_str}s exceeds the total "
f"number of available {device_str}s in the placement group.")
# Create a new placement group
placement_group_specs = ([{
device_str: 1
}] * parallel_config.world_size)
placement_group_specs: List[Dict[str, float]] = ([{
device_str: 1.0
} for _ in range(parallel_config.world_size)])
# vLLM engine is also a worker to execute model with an accelerator,
# so it requires to have the device in a current node. Check if
# the current node has at least one device.
current_ip = get_ip()
current_node_id = ray.get_runtime_context().get_node_id()
current_node_resource = available_resources_per_node()[current_node_id]
if current_node_resource.get(device_str, 0) < 1:
raise ValueError(
f"Current node has no {device_str} available. "
f"{current_node_resource=}. vLLM engine cannot start without "
f"{device_str}. Make sure you have at least 1 {device_str} "
f"available in a node {current_node_id=} {current_ip=}.")
# This way, at least bundle is required to be created in a current
# node.
placement_group_specs[0][f"node:{current_ip}"] = 0.001
# By default, Ray packs resources as much as possible.
current_placement_group = ray.util.placement_group(
placement_group_specs)
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
ray.get(current_placement_group.ready(), timeout=1800)
placement_group_specs, strategy="PACK")
_wait_until_pg_ready(current_placement_group)
assert current_placement_group is not None
_verify_bundles(current_placement_group, parallel_config, device_str)
# Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group
import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
Tuple, Union)
from typing import List, Optional
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync
from vllm.executor.xpu_executor import XPUExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
from vllm.utils import get_vllm_instance_id, make_async
logger = init_logger(__name__)
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayXPUExecutor(DistributedGPUExecutor):
uses_ray: bool = True
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
assert device_config.device_type == "xpu"
assert (not speculative_config
), "Speculative decoding not yet supported for XPU backend"
self.model_config = model_config
self.cache_config = cache_config
self.load_config = load_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
def _init_executor(self) -> None:
pass
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- Tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks", )
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return num_gpu_blocks, num_cpu_blocks
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
return dict(
worker_module_name="vllm.worker.xpu_worker",
worker_class_name="XPUWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
# For single GPU case, we use a ray worker with constrained memory.
num_gpus = self.cache_config.gpu_memory_utilization
else:
# Otherwise, the ray workers are allocated with a full GPU.
num_gpus = 1
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
# Create the workers.
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
else:
# Else, added to the list of workers.
self.workers.append(worker)
if self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")
class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
def _get_env_vars_to_be_updated(self):
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)
node_workers = defaultdict(list)
node_gpus = defaultdict(list)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
# TODO: add env var for xpu
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
def collect_arg_helper_func(**kwargs):
# avoid writing `{"name": value}` manually
return kwargs
init_worker_all_kwargs = []
# Initialize the actual workers inside worker wrapper.
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
local_rank = node_workers[node_id].index(rank)
init_worker_all_kwargs.append(
collect_arg_helper_func(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
is_driver_worker=rank == 0,
))
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
VLLM_INSTANCE_ID = get_vllm_instance_id()
self._run_workers("init_device")
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers,
)
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for (_, _) in worker_node_and_gpu_ids]
return all_args_to_update_environment_variables
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache in all workers.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info("# GPU blocks: %d, "
"# CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")
def _run_workers(
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
ways:
- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
if async_run_remote_workers_only:
# Just return futures
return ray_worker_outputs
driver_worker_output = []
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
else:
assert self.driver_dummy_worker is not None
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self, enable_asyncio: bool):
import pkg_resources
from packaging import version
required_version = version.parse("2.32")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with InputNode() as input_data:
forward_dag = MultiOutputNode([
worker.execute_model_compiled_dag_remote.
bind( # type: ignore[attr-defined]
input_data) for worker in self.workers
])
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()
def _check_if_any_actor_is_dead(self):
if not self.workers:
return
dead_actors = []
for actor in self.workers:
actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access
if actor_state["State"] == "DEAD":
dead_actors.append(actor)
if dead_actors:
raise RuntimeError("At least one Worker is dead. "
f"Dead Workers: {dead_actors}. ")
class RayXPUExecutorAsync(RayXPUExecutor, DistributedGPUExecutorAsync):
class RayXPUExecutorAsync(RayXPUExecutor, RayGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_method = make_async(self.driver_worker.execute_method)
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_method("execute_model",
execute_model_req)
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers
]
return await asyncio.gather(*coros)
self.pp_locks: Optional[List[asyncio.Lock]] = None
......@@ -52,7 +52,6 @@ class TPUExecutor(ExecutorBase):
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
multimodal_config=self.multimodal_config,
is_driver_worker=rank == 0,
)
......
from typing import List, Optional
from typing import Callable, List, Optional, Tuple, Type, Union
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerWrapperBase
from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__)
......@@ -29,9 +29,9 @@ class XPUExecutor(GPUExecutor):
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
assert device_config.device_type == "xpu"
assert (not speculative_config
......@@ -46,35 +46,27 @@ class XPUExecutor(GPUExecutor):
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config
self.speculative_config = None
self.observability_config = observability_config
# Instantiate the worker and load the model to GPU.
self._init_executor()
def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
if self.speculative_config is None:
worker_module_name = "vllm.worker.xpu_worker"
worker_class_name = "XPUWorker"
else:
def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
if self.speculative_config is not None:
raise NotImplementedError(
"XPU does not support speculative decoding")
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
else:
worker_module_name = "vllm.worker.xpu_worker"
worker_class_name = "XPUWorker"
return (worker_module_name, worker_class_name, worker_class_fn)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
output = self.driver_worker.execute_model(execute_model_req)
return output
......
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
TextPrompt, TokensPrompt, parse_and_batch_prompt)
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry()
......@@ -12,7 +14,17 @@ See also:
"""
__all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
"TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
"InputContext", "InputRegistry"
"TextPrompt",
"TokensPrompt",
"PromptInputs",
"SingletonPromptInputs",
"ExplicitEncoderDecoderPrompt",
"LLMInputs",
"EncoderDecoderLLMInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"INPUT_REGISTRY",
"InputContext",
"InputRegistry",
]
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)
from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
Union)
from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0], str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False)
for elem in cast(List[str], prompt)
]
if isinstance(prompt[0], int):
# case 3: array of tokens
elem = cast(List[int], prompt)
return [ParsedTokens(content=elem, is_tokens=True)]
if isinstance(prompt[0], list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0][0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in cast(List[List[int]], prompt)
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
class TextPrompt(TypedDict):
"""Schema for a text prompt."""
......@@ -92,12 +33,71 @@ class TokensPrompt(TypedDict):
"""
PromptInputs = Union[str, TextPrompt, TokensPrompt]
SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:
Set of possible schemas for a single LLM input:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPromptInputs` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
_T1_co = TypeVar("_T1_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
_T2_co = TypeVar("_T2_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
"""Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
The encoder and decoder prompts, respectively,
may formatted according to any of the
:class:`SingletonPromptInputs` schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
:class:`SingletonPromptInputs` instances.
"""
encoder_prompt: _T1_co
decoder_prompt: Optional[_T2_co]
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- A single data structure containing both an encoder and a decoder prompt
(:class:`ExplicitEncoderDecoderPrompt`)
"""
......@@ -105,6 +105,8 @@ class LLMInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
......@@ -119,3 +121,58 @@ class LLMInputs(TypedDict):
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class EncoderDecoderLLMInputs(LLMInputs):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder_prompt_token_ids: List[int]
"""The token IDs of the encoder prompt."""
encoder_prompt: NotRequired[Optional[str]]
"""
The original encoder prompt text corresponding to the token IDs, if
available.
"""
_T1 = TypeVar("_T1",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
_T2 = TypeVar("_T2",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
def build_explicit_enc_dec_prompt(
encoder_prompt: _T1,
decoder_prompt: Optional[_T2],
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt)
def zip_enc_dec_prompts(
enc_prompts: Iterable[_T1],
dec_prompts: Iterable[Optional[_T2]],
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
"""
Zip encoder and decoder prompts together into a list of
:class:`ExplicitEncoderDecoderPrompt` instances.
"""
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
]
def to_enc_dec_tuple_list(
enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> List[Tuple[_T1, Optional[_T2]]]:
return [(enc_dec_prompt["encoder_prompt"],
enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts]
from typing import List, Literal, Sequence, TypedDict, Union, overload
from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs)
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if is_list_of(prompt, str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False) for elem in prompt
]
if is_list_of(prompt, int):
# case 3: array of tokens
return [ParsedTokens(content=prompt, is_tokens=True)]
if is_list_of(prompt, list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if is_list_of(prompt[0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in prompt
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
def is_explicit_encoder_decoder_prompt(
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(inputs, dict) and "encoder_prompt" in inputs
def is_valid_encoder_decoder_llm_inputs(
inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
) -> TypeIs[EncoderDecoderLLMInputs]:
return "encoder_prompt_token_ids" in inputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment