Commit af7f4372 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
import asyncio
import time import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Union from typing import Union
from fastapi import Request from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage, from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_chat_template,
load_chat_template, load_chat_template,
parse_chat_messages) parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
...@@ -22,14 +23,15 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -22,14 +23,15 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing, OpenAIServing,
PromptAdapterPath) PromptAdapterPath)
from vllm.inputs import PromptInputs from vllm.inputs import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
from vllm.utils import random_uuid from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -65,9 +67,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -65,9 +67,9 @@ class OpenAIServingChat(OpenAIServing):
async def create_chat_completion( async def create_chat_completion(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
raw_request: Optional[Request] = None raw_request: Optional[Request] = None,
) -> Union[ErrorResponse, AsyncGenerator[str, None], ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ChatCompletionResponse]: ErrorResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create See https://platform.openai.com/docs/api-reference/chat/create
...@@ -98,16 +100,15 @@ class OpenAIServingChat(OpenAIServing): ...@@ -98,16 +100,15 @@ class OpenAIServingChat(OpenAIServing):
tool.model_dump() for tool in request.tools tool.model_dump() for tool in request.tools
] ]
prompt = tokenizer.apply_chat_template( prompt = apply_chat_template(
tokenizer,
conversation=conversation, conversation=conversation,
tokenize=False, chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts, tools=tool_dicts,
documents=request.documents, documents=request.documents,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}), **(request.chat_template_kwargs or {}),
) )
assert isinstance(prompt, str)
except Exception as e: except Exception as e:
logger.error("Error in applying chat template from request: %s", e) logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -149,9 +150,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -149,9 +150,8 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
engine_inputs: PromptInputs = { engine_inputs = TokensPrompt(
"prompt_token_ids": prompt_inputs["prompt_token_ids"], prompt_token_ids=prompt_inputs["prompt_token_ids"])
}
if mm_data is not None: if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data engine_inputs["multi_modal_data"] = mm_data
...@@ -176,18 +176,20 @@ class OpenAIServingChat(OpenAIServing): ...@@ -176,18 +176,20 @@ class OpenAIServingChat(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
if raw_request:
result_generator = iterate_with_cancellation(
result_generator, raw_request.is_disconnected)
# Streaming response # Streaming response
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer) request, result_generator, request_id, conversation, tokenizer)
else: try:
try: return await self.chat_completion_full_generator(
return await self.chat_completion_full_generator( request, result_generator, request_id, conversation, tokenizer)
request, raw_request, result_generator, request_id, except ValueError as e:
conversation, tokenizer) # TODO: Use a vllm-specific Validation Error
except ValueError as e: return self.create_error_response(str(e))
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def get_chat_request_role(self, request: ChatCompletionRequest) -> str: def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt: if request.add_generation_prompt:
...@@ -201,11 +203,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -201,11 +203,11 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput], result_generator: AsyncIterator[RequestOutput],
request_id: str, request_id: str,
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0] model_name = self.served_model_names[0]
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True first_iteration = True
# Send response for each token for each request.n (index) # Send response for each token for each request.n (index)
...@@ -422,23 +424,22 @@ class OpenAIServingChat(OpenAIServing): ...@@ -422,23 +424,22 @@ class OpenAIServingChat(OpenAIServing):
async def chat_completion_full_generator( async def chat_completion_full_generator(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput], result_generator: AsyncIterator[RequestOutput],
request_id: str, request_id: str,
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]: ) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0] model_name = self.served_model_names[0]
created_time = int(time.time()) created_time = int(time.time())
final_res: Optional[RequestOutput] = None final_res: Optional[RequestOutput] = None
async for res in result_generator: try:
if raw_request is not None and await raw_request.is_disconnected(): async for res in result_generator:
# Abort the request if the client disconnects. final_res = res
await self.async_engine_client.abort(request_id) except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
final_res = res
assert final_res is not None assert final_res is not None
choices: List[ChatCompletionResponseChoice] = [] choices: List[ChatCompletionResponseChoice] = []
...@@ -504,13 +505,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -504,13 +505,14 @@ class OpenAIServingChat(OpenAIServing):
model=model_name, model=model_name,
choices=choices, choices=choices,
usage=usage, usage=usage,
prompt_logprobs=final_res.prompt_logprobs,
) )
return response return response
def _get_top_logprobs( def _get_top_logprobs(
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]: tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
return [ return [
ChatCompletionLogProb(token=(token := self._get_decoded_token( ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1], p[1],
...@@ -528,12 +530,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -528,12 +530,11 @@ class OpenAIServingChat(OpenAIServing):
self, self,
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
num_output_top_logprobs: Optional[int] = None, num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs: ) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs.""" """Create OpenAI-style logprobs."""
logprobs_content: List[ChatCompletionLogProbsContent] = []
logprobs_content = []
for i, token_id in enumerate(token_ids): for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
...@@ -541,23 +542,32 @@ class OpenAIServingChat(OpenAIServing): ...@@ -541,23 +542,32 @@ class OpenAIServingChat(OpenAIServing):
token = tokenizer.decode(token_id) token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids: if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
logprobs_content.append( logprobs_content.append(
ChatCompletionLogProbsContent( ChatCompletionLogProbsContent(
token=token, token=token,
bytes=list(token.encode("utf-8", errors="replace")))) bytes=list(token.encode("utf-8", errors="replace")),
))
else: else:
step_token = step_top_logprobs[token_id]
step_decoded = step_token.decoded_token
logprobs_content.append( logprobs_content.append(
ChatCompletionLogProbsContent( ChatCompletionLogProbsContent(
token=self._get_decoded_token( token=self._get_decoded_token(
step_top_logprobs[token_id], token_id, tokenizer, step_token,
self.return_tokens_as_token_ids), token_id,
logprob=max(step_top_logprobs[token_id].logprob, tokenizer,
-9999.0), self.return_tokens_as_token_ids,
bytes=list( ),
step_top_logprobs[token_id].decoded_token.encode( logprob=max(step_token.logprob, -9999.0),
"utf-8", errors="replace")), bytes=None if step_decoded is None else list(
step_decoded.encode("utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs( top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs, step_top_logprobs,
tokenizer))) num_output_top_logprobs,
tokenizer,
),
))
return ChatCompletionLogProbs(content=logprobs_content) return ChatCompletionLogProbs(content=logprobs_content)
import asyncio
import time import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional) Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Tuple, cast from typing import Tuple, Union, cast
from fastapi import Request from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
...@@ -18,7 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, ...@@ -18,7 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, CompletionStreamResponse,
UsageInfo) ErrorResponse, UsageInfo)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing, OpenAIServing,
...@@ -28,6 +28,7 @@ from vllm.outputs import RequestOutput ...@@ -28,6 +28,7 @@ from vllm.outputs import RequestOutput
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -59,8 +60,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -59,8 +60,11 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids) return_tokens_as_token_ids=return_tokens_as_token_ids)
async def create_completion(self, request: CompletionRequest, async def create_completion(
raw_request: Request): self,
request: CompletionRequest,
raw_request: Request,
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create See https://platform.openai.com/docs/api-reference/completions/create
...@@ -84,7 +88,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -84,7 +88,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time = int(time.time()) created_time = int(time.time())
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = [] generators: List[AsyncGenerator[RequestOutput, None]] = []
try: try:
( (
lora_request, lora_request,
...@@ -143,8 +147,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -143,8 +147,8 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[ result_generator = merge_async_iterators(
int, RequestOutput]] = merge_async_iterators(*generators) *generators, is_cancelled=raw_request.is_disconnected)
# Similar to the OpenAI API, when n != best_of, we do not stream the # Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use # results. In addition, we do not stream the results when use
...@@ -156,7 +160,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -156,7 +160,6 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response # Streaming response
if stream: if stream:
return self.completion_stream_generator(request, return self.completion_stream_generator(request,
raw_request,
result_generator, result_generator,
request_id, request_id,
created_time, created_time,
...@@ -168,10 +171,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -168,10 +171,6 @@ class OpenAIServingCompletion(OpenAIServing):
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
try: try:
async for i, res in result_generator: async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch): for i, final_res in enumerate(final_res_batch):
...@@ -194,6 +193,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -194,6 +193,8 @@ class OpenAIServingCompletion(OpenAIServing):
model_name, model_name,
tokenizer, tokenizer,
) )
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -214,13 +215,12 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -214,13 +215,12 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator( async def completion_stream_generator(
self, self,
request: CompletionRequest, request: CompletionRequest,
raw_request: Request,
result_generator: AsyncIterator[Tuple[int, RequestOutput]], result_generator: AsyncIterator[Tuple[int, RequestOutput]],
request_id: str, request_id: str,
created_time: int, created_time: int,
model_name: str, model_name: str,
num_prompts: int, num_prompts: int,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts previous_texts = [""] * num_choices * num_prompts
...@@ -229,12 +229,13 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -229,12 +229,13 @@ class OpenAIServingCompletion(OpenAIServing):
try: try:
async for prompt_idx, res in result_generator: async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt
# Abort the request if the client disconnects. delta_token_ids: GenericSequence[int]
if await raw_request.is_disconnected(): out_logprobs: Optional[GenericSequence[Optional[Dict[
await self.async_engine_client.abort( int, Logprob]]]]
f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
for output in res.outputs: for output in res.outputs:
i = output.index + prompt_idx * num_choices i = output.index + prompt_idx * num_choices
...@@ -243,19 +244,25 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -243,19 +244,25 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
assert prompt_text is not None
# only return the prompt # only return the prompt
delta_text = res.prompt delta_text = prompt_text
delta_token_ids = res.prompt_token_ids delta_token_ids = prompt_token_ids
out_logprobs = res.prompt_logprobs out_logprobs = prompt_logprobs
has_echoed[i] = True has_echoed[i] = True
elif (request.echo and request.max_tokens > 0 elif (request.echo and request.max_tokens > 0
and not has_echoed[i]): and not has_echoed[i]):
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token # echo the prompt and first token
delta_text = res.prompt + output.text delta_text = prompt_text + output.text
delta_token_ids = (res.prompt_token_ids + delta_token_ids = [
output.token_ids) *prompt_token_ids, *output.token_ids
out_logprobs = res.prompt_logprobs + (output.logprobs ]
or []) out_logprobs = [
*prompt_logprobs,
*(output.logprobs or []),
]
has_echoed[i] = True has_echoed[i] = True
else: else:
# return just the delta # return just the delta
...@@ -300,7 +307,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -300,7 +307,7 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage): and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None): or output.finish_reason is not None):
prompt_tokens = len(res.prompt_token_ids) prompt_tokens = len(prompt_token_ids)
completion_tokens = len(output.token_ids) completion_tokens = len(output.token_ids)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
...@@ -341,7 +348,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -341,7 +348,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str, request_id: str,
created_time: int, created_time: int,
model_name: str, model_name: str,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> CompletionResponse: ) -> CompletionResponse:
choices: List[CompletionResponseChoice] = [] choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
...@@ -352,16 +359,31 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -352,16 +359,31 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = final_res.prompt_logprobs prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
Logprob]]]]
for output in final_res.outputs: for output in final_res.outputs:
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
assert prompt_text is not None
token_ids = prompt_token_ids token_ids = prompt_token_ids
out_logprobs = prompt_logprobs out_logprobs = prompt_logprobs
output_text = prompt_text output_text = prompt_text
elif request.echo and request.max_tokens > 0: elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + list(output.token_ids) assert prompt_text is not None
out_logprobs = (prompt_logprobs + output.logprobs token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is not None else None)
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text output_text = prompt_text + output.text
else: else:
token_ids = output.token_ids token_ids = output.token_ids
...@@ -385,6 +407,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -385,6 +407,7 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs=logprobs, logprobs=logprobs,
finish_reason=output.finish_reason, finish_reason=output.finish_reason,
stop_reason=output.stop_reason, stop_reason=output.stop_reason,
prompt_logprobs=final_res.prompt_logprobs,
) )
choices.append(choice_data) choices.append(choice_data)
...@@ -411,7 +434,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -411,7 +434,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int, num_output_top_logprobs: int,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
initial_text_offset: int = 0, initial_text_offset: int = 0,
) -> CompletionLogProbs: ) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API.""" """Create logprobs for OpenAI Completion API."""
...@@ -428,17 +451,21 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -428,17 +451,21 @@ class OpenAIServingCompletion(OpenAIServing):
token = tokenizer.decode(token_id) token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids: if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
out_tokens.append(token) out_tokens.append(token)
out_token_logprobs.append(None) out_token_logprobs.append(None)
out_top_logprobs.append(None) out_top_logprobs.append(None)
else: else:
step_token = step_top_logprobs[token_id]
token = self._get_decoded_token( token = self._get_decoded_token(
step_top_logprobs[token_id], step_token,
token_id, token_id,
tokenizer, tokenizer,
return_as_token_id=self.return_tokens_as_token_ids) return_as_token_id=self.return_tokens_as_token_ids,
token_logprob = max(step_top_logprobs[token_id].logprob, )
-9999.0) token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token) out_tokens.append(token)
out_token_logprobs.append(token_logprob) out_token_logprobs.append(token_logprob)
......
import asyncio
import base64 import base64
import time import time
from typing import AsyncIterator, List, Optional, Tuple, cast from typing import AsyncGenerator, List, Literal, Optional, Union, cast
import numpy as np import numpy as np
from fastapi import Request from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest, from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse,
EmbeddingResponseData, UsageInfo) EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -21,18 +24,28 @@ logger = init_logger(__name__) ...@@ -21,18 +24,28 @@ logger = init_logger(__name__)
TypeTokenIDs = List[int] TypeTokenIDs = List[int]
def _get_embedding(
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
return output.embedding
elif encoding_format == "base64":
embedding_bytes = np.array(output.embedding).tobytes()
return base64.b64encode(embedding_bytes).decode("utf-8")
assert_never(encoding_format)
def request_output_to_embedding_response( def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str, final_res_batch: List[EmbeddingRequestOutput], request_id: str,
created_time: int, model_name: str, created_time: int, model_name: str,
encoding_format: str) -> EmbeddingResponse: encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = [] data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch): for idx, final_res in enumerate(final_res_batch):
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding embedding = _get_embedding(final_res.outputs, encoding_format)
if encoding_format == "base64":
embedding_bytes = np.array(embedding).tobytes()
embedding = base64.b64encode(embedding_bytes).decode("utf-8")
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data) data.append(embedding_data)
...@@ -68,21 +81,25 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -68,21 +81,25 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_modules=None, lora_modules=None,
prompt_adapters=None, prompt_adapters=None,
request_logger=request_logger) request_logger=request_logger)
self._check_embedding_mode(model_config.embedding_mode) self._enabled = self._check_embedding_mode(model_config.embedding_mode)
async def create_embedding(self, request: EmbeddingRequest, async def create_embedding(
raw_request: Request): self,
request: EmbeddingRequest,
raw_request: Optional[Request] = None,
) -> Union[EmbeddingResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API. for the API specification. This API mimics the OpenAI Embedding API.
""" """
if not self._enabled:
return self.create_error_response("Embedding API disabled")
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
encoding_format = (request.encoding_format encoding_format = request.encoding_format
if request.encoding_format else "float")
if request.dimensions is not None: if request.dimensions is not None:
return self.create_error_response( return self.create_error_response(
"dimensions is currently not supported") "dimensions is currently not supported")
...@@ -92,7 +109,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -92,7 +109,7 @@ class OpenAIServingEmbedding(OpenAIServing):
created_time = int(time.monotonic()) created_time = int(time.monotonic())
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: List[AsyncIterator[EmbeddingRequestOutput]] = [] generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try: try:
( (
lora_request, lora_request,
...@@ -137,18 +154,16 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -137,18 +154,16 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[ result_generator = merge_async_iterators(
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) *generators,
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
# Non-streaming response # Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]] final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts) final_res_batch = [None] * len(prompts)
try: try:
async for i, res in result_generator: async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res final_res_batch[i] = res
for final_res in final_res_batch: for final_res in final_res_batch:
...@@ -160,15 +175,18 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -160,15 +175,18 @@ class OpenAIServingEmbedding(OpenAIServing):
response = request_output_to_embedding_response( response = request_output_to_embedding_response(
final_res_batch_checked, request_id, created_time, model_name, final_res_batch_checked, request_id, created_time, model_name,
encoding_format) encoding_format)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
return response return response
def _check_embedding_mode(self, embedding_mode: bool): def _check_embedding_mode(self, embedding_mode: bool) -> bool:
if not embedding_mode: if not embedding_mode:
logger.warning( logger.warning(
"embedding_mode is False. Embedding API will not work.") "embedding_mode is False. Embedding API will not work.")
else: else:
logger.info("Activating the server engine with embedding enabled.") logger.info("Activating the server engine with embedding enabled.")
return embedding_mode
...@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeRequest) TokenizeRequest)
# yapf: enable # yapf: enable
from vllm.inputs import parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
...@@ -31,7 +31,7 @@ from vllm.pooling_params import PoolingParams ...@@ -31,7 +31,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -2,7 +2,9 @@ from typing import List, Optional, Union ...@@ -2,7 +2,9 @@ from typing import List, Optional, Union
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages from vllm.entrypoints.chat_utils import (apply_chat_template,
load_chat_template,
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -70,12 +72,12 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -70,12 +72,12 @@ class OpenAIServingTokenization(OpenAIServing):
logger.warning( logger.warning(
"Multi-modal inputs are ignored during tokenization") "Multi-modal inputs are ignored during tokenization")
prompt = tokenizer.apply_chat_template( prompt = apply_chat_template(
add_generation_prompt=request.add_generation_prompt, tokenizer,
conversation=conversation, conversation=conversation,
tokenize=False, chat_template=self.chat_template,
chat_template=self.chat_template) add_generation_prompt=request.add_generation_prompt,
assert isinstance(prompt, str) )
else: else:
prompt = request.prompt prompt = request.prompt
......
import os import os
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
VLLM_HOST_IP: str = "" VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None VLLM_PORT: Optional[int] = None
VLLM_RPC_PORT: int = 5570 VLLM_RPC_BASE_PATH: str = tempfile.gettempdir()
VLLM_USE_MODELSCOPE: bool = False VLLM_USE_MODELSCOPE: bool = False
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_INSTANCE_ID: Optional[str] = None VLLM_INSTANCE_ID: Optional[str] = None
...@@ -32,6 +33,7 @@ if TYPE_CHECKING: ...@@ -32,6 +33,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_OMP_THREADS_BIND: str = ""
...@@ -46,14 +48,21 @@ if TYPE_CHECKING: ...@@ -46,14 +48,21 @@ if TYPE_CHECKING:
VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_AUDIO_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False VLLM_USE_PRECOMPILED: bool = False
VLLM_NO_DEPRECATION_WARNING: bool = False VLLM_NO_DEPRECATION_WARNING: bool = False
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
def get_default_cache_root(): def get_default_cache_root():
...@@ -132,7 +141,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -132,7 +141,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
os.path.join(get_default_cache_root(), "vllm"), os.path.join(get_default_cache_root(), "vllm"),
)), )),
# used in distributed environment to determine the master address # used in distributed environment to determine the ip address
# of the current node, when the node has multiple network interfaces.
# If you are using multi-node inference, you should set this differently
# on each node.
'VLLM_HOST_IP': 'VLLM_HOST_IP':
lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""),
...@@ -145,10 +157,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -145,10 +157,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: int(os.getenv('VLLM_PORT', '0')) lambda: int(os.getenv('VLLM_PORT', '0'))
if 'VLLM_PORT' in os.environ else None, if 'VLLM_PORT' in os.environ else None,
# used when the frontend api server is running in multi-processing mode, # path used for ipc when the frontend api server is running in
# to communicate with the backend engine process over ZMQ. # multi-processing mode to communicate with the backend engine process.
'VLLM_RPC_PORT': 'VLLM_RPC_BASE_PATH':
lambda: int(os.getenv('VLLM_PORT', '5570')), lambda: os.getenv('VLLM_RPC_BASE_PATH', tempfile.gettempdir()),
# If true, will load models from ModelScope instead of Hugging Face Hub. # If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers # note that the value is true or false, not numbers
...@@ -268,6 +280,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -268,6 +280,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ATTENTION_BACKEND": "VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
# Pipeline stage partition strategy # Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION": "VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
...@@ -336,12 +352,17 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -336,12 +352,17 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_IMAGE_FETCH_TIMEOUT": "VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
# Timeout for fetching audio when serving multimodal models
# Default is 5 seconds
"VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "5")),
# Path to the XLA persistent cache directory. # Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs. # Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH": "VLLM_XLA_CACHE_PATH":
lambda: os.path.expanduser( lambda: os.path.expanduser(
os.getenv( os.getenv(
"VLLM_ASSETS_CACHE", "VLLM_XLA_CACHE_PATH",
os.path.join(get_default_cache_root(), "vllm", "xla_cache"), os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
)), )),
"VLLM_FUSED_MOE_CHUNK_SIZE": "VLLM_FUSED_MOE_CHUNK_SIZE":
...@@ -351,6 +372,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -351,6 +372,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_NO_DEPRECATION_WARNING": "VLLM_NO_DEPRECATION_WARNING":
lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),
# If set, the OpenAI API server will stay alive even after the underlying
# AsyncLLMEngine errors and stops serving requests
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH":
lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)),
# If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
# the user to specify a max sequence length greater than # the user to specify a max sequence length greater than
# the max length derived from the model's config.json. # the max length derived from the model's config.json.
...@@ -359,6 +385,39 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -359,6 +385,39 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: lambda:
(os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
("1", "true")), ("1", "true")),
# If set, forces FP8 Marlin to be used for FP8 quantization regardless
# of the hardware support for FP8 compute.
"VLLM_TEST_FORCE_FP8_MARLIN":
lambda:
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
("1", "true")),
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
"VLLM_RPC_GET_DATA_TIMEOUT_MS":
lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
# If set, allow running the engine as a separate ray actor,
# which is a deprecated feature soon to be removed.
# See https://github.com/vllm-project/vllm/issues/7045
"VLLM_ALLOW_ENGINE_USE_RAY":
lambda:
(os.environ.get("VLLM_ALLOW_ENGINE_USE_RAY", "0").strip().lower() in
("1", "true")),
# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded
# if this is set to an empty string, no plugins will be loaded
"VLLM_PLUGINS":
lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
"VLLM_PLUGINS"].split(","),
# Enables torch profiler if set. Path to the directory where torch profiler
# traces are saved. Note that it must be an absolute path.
"VLLM_TORCH_PROFILER_DIR":
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -13,7 +13,7 @@ from vllm.logger import init_logger ...@@ -13,7 +13,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_open_port, from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async) get_vllm_instance_id, make_async)
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
...@@ -141,7 +141,6 @@ class CPUExecutor(ExecutorBase): ...@@ -141,7 +141,6 @@ class CPUExecutor(ExecutorBase):
rank=rank, rank=rank,
distributed_init_method=self.distributed_init_method, distributed_init_method=self.distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config, prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=rank == 0, is_driver_worker=rank == 0,
...@@ -332,7 +331,6 @@ def _verify_and_get_scheduler_config( ...@@ -332,7 +331,6 @@ def _verify_and_get_scheduler_config(
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
_GB = 1 << 30
if config.enable_prefix_caching: if config.enable_prefix_caching:
logger.warning("Prefix caching is not supported on CPU, disable it.") logger.warning("Prefix caching is not supported on CPU, disable it.")
config.enable_prefix_caching = False config.enable_prefix_caching = False
...@@ -341,11 +339,11 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: ...@@ -341,11 +339,11 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
if kv_cache_space >= 0: if kv_cache_space >= 0:
if kv_cache_space == 0: if kv_cache_space == 0:
config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
"for CPU backend is not set, using 4 by default.") "for CPU backend is not set, using 4 by default.")
else: else:
config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
else: else:
raise RuntimeError( raise RuntimeError(
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE" "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
......
...@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod ...@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -29,9 +29,9 @@ class ExecutorBase(ABC): ...@@ -29,9 +29,9 @@ class ExecutorBase(ABC):
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig], prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
...@@ -40,10 +40,9 @@ class ExecutorBase(ABC): ...@@ -40,10 +40,9 @@ class ExecutorBase(ABC):
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.multimodal_config = multimodal_config
self.speculative_config = speculative_config self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
self._init_executor() self._init_executor()
@abstractmethod @abstractmethod
......
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -7,15 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -7,15 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)
def create_worker(worker_module_name, worker_class_name, **kwargs): def create_worker(worker_module_name: str, worker_class_name: str,
worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
**kwargs):
wrapper = WorkerWrapperBase( wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name, worker_module_name=worker_module_name,
worker_class_name=worker_class_name, worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
) )
wrapper.init_worker(**kwargs) wrapper.init_worker(**kwargs)
return wrapper.worker return wrapper.worker
...@@ -55,13 +58,27 @@ class GPUExecutor(ExecutorBase): ...@@ -55,13 +58,27 @@ class GPUExecutor(ExecutorBase):
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
speculative_config=self.speculative_config, speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config, prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config) is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0), or (rank % self.parallel_config.tensor_parallel_size == 0),
observability_config=self.observability_config,
) )
def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
if self.scheduler_config.is_multi_step:
worker_module_name = "vllm.worker.multi_step_worker"
worker_class_name = "MultiStepWorker"
elif self.speculative_config:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
return (worker_module_name, worker_class_name, worker_class_fn)
def _get_create_worker_kwargs( def _get_create_worker_kwargs(
self, self,
local_rank: int = 0, local_rank: int = 0,
...@@ -69,13 +86,15 @@ class GPUExecutor(ExecutorBase): ...@@ -69,13 +86,15 @@ class GPUExecutor(ExecutorBase):
distributed_init_method: Optional[str] = None) -> Dict: distributed_init_method: Optional[str] = None) -> Dict:
worker_kwargs = self._get_worker_kwargs(local_rank, rank, worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method) distributed_init_method)
if self.speculative_config is None:
worker_kwargs.update(worker_module_name="vllm.worker.worker", (worker_module_name, worker_class_name,
worker_class_name="Worker") worker_class_fn) = self._get_worker_module_and_class()
else: worker_kwargs.update(
worker_kwargs.update( worker_module_name=worker_module_name,
worker_module_name="vllm.spec_decode.spec_decode_worker", worker_class_name=worker_class_name,
worker_class_name="create_spec_worker") worker_class_fn=worker_class_fn,
)
return worker_kwargs return worker_kwargs
def _create_worker(self, def _create_worker(self,
......
from array import array
from typing import Any, Type
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any:
"""Custom msgspec enc hook that supports array types.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if isinstance(obj, array):
assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
f"Given array has a type code of {obj.typecode}.")
return obj.tobytes()
def decode_hook(type: Type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if type is array:
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
deserialized.frombytes(obj)
return deserialized
...@@ -100,9 +100,8 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): ...@@ -100,9 +100,8 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
self, self,
execute_model_req: ExecuteModelRequest, execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
output = await make_async( output = await make_async(self.driver_worker.execute_model
self.driver_worker.execute_model )(execute_model_req=execute_model_req, )
)(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
return output return output
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
......
...@@ -10,8 +10,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase ...@@ -10,8 +10,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
make_async) get_open_port, make_async)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -49,7 +49,6 @@ class OpenVINOExecutor(ExecutorBase): ...@@ -49,7 +49,6 @@ class OpenVINOExecutor(ExecutorBase):
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True, is_driver_worker=True,
) )
...@@ -165,14 +164,13 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: ...@@ -165,14 +164,13 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
if kv_cache_space >= 0: if kv_cache_space >= 0:
_GB = 1 << 30
if kv_cache_space == 0: if kv_cache_space == 0:
config.openvino_kvcache_space_bytes = 4 * _GB # type: ignore config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning( logger.warning(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
"for OpenVINO backend is not set, using 4 by default.") "for OpenVINO backend is not set, using 4 by default.")
else: else:
config.openvino_kvcache_space_bytes = kv_cache_space * _GB # type: ignore config.openvino_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
else: else:
raise RuntimeError( raise RuntimeError(
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
......
...@@ -4,9 +4,12 @@ from collections import defaultdict ...@@ -4,9 +4,12 @@ from collections import defaultdict
from itertools import islice, repeat from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import msgspec
import vllm.envs as envs import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import ( # yapf: disable from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync) DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
...@@ -60,6 +63,18 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -60,6 +63,18 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers. # Create the parallel GPU workers.
self._init_workers_ray(placement_group) self._init_workers_ray(placement_group)
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
def shutdown(self) -> None:
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.forward_dag = None
def _configure_ray_workers_use_nsight(self, def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]: ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling # If nsight profiling is enabled, we need to set the profiling
...@@ -76,19 +91,20 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -76,19 +91,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
return ray_remote_kwargs return ray_remote_kwargs
def _get_worker_wrapper_args(self) -> Dict[str, Any]: def _get_worker_wrapper_args(self) -> Dict[str, Any]:
if self.speculative_config is not None: (worker_module_name, worker_class_name,
worker_module_name = "vllm.spec_decode.spec_decode_worker" worker_class_fn) = self._get_worker_module_and_class()
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
return dict( return dict(
worker_module_name=worker_module_name, worker_module_name=worker_module_name,
worker_class_name=worker_class_name, worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
# child class could overwrite this to return actual env vars.
def _get_env_vars_to_be_updated(self):
return self._env_vars_for_all_workers
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if (self.parallel_config.tensor_parallel_size == 1 if (self.parallel_config.tensor_parallel_size == 1
...@@ -115,9 +131,9 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -115,9 +131,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
ray_remote_kwargs) ray_remote_kwargs)
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers. # Create the workers.
driver_ip = get_ip() driver_ip = get_ip()
logger.info("driver_ip: %s", driver_ip)
worker_wrapper_kwargs = self._get_worker_wrapper_args() worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs): for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0): if not bundle.get("GPU", 0):
...@@ -202,6 +218,19 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -202,6 +218,19 @@ class RayGPUExecutor(DistributedGPUExecutor):
for node_id, gpu_ids in node_gpus.items(): for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids) node_gpus[node_id] = sorted(gpu_ids)
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP` or "
"`HOST_IP` environment variable, make sure it is unique for"
" each node.")
VLLM_INSTANCE_ID = get_vllm_instance_id() VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers. # Set environment variables for the driver and workers.
...@@ -213,8 +242,12 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -213,8 +242,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
"VLLM_TRACE_FUNCTION": "VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION), str(envs.VLLM_TRACE_FUNCTION),
}, ) for (node_id, _) in worker_node_and_gpu_ids] }, ) for (node_id, _) in worker_node_and_gpu_ids]
self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)
self._run_workers("update_environment_variables", self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables) all_args=self._get_env_vars_to_be_updated())
if len(node_gpus) == 1: if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address. # in single node case, we don't need to get the IP address.
...@@ -297,8 +330,10 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -297,8 +330,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
if self.forward_dag is None: if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
outputs = ray.get(self.forward_dag.execute(execute_model_req)) serialized_data = self.input_encoder.encode(execute_model_req)
return outputs[0] outputs = ray.get(self.forward_dag.execute(serialized_data))
output = self.output_decoder.decode(outputs[0])
return output
def _run_workers( def _run_workers(
self, self,
...@@ -446,11 +481,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -446,11 +481,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def __del__(self): def __del__(self):
if self.forward_dag is not None: self.shutdown()
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
...@@ -472,9 +503,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): ...@@ -472,9 +503,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
if self.forward_dag is None: if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
dag_future = await self.forward_dag.execute_async(execute_model_req) serialized_data = self.input_encoder.encode(execute_model_req)
dag_future = await self.forward_dag.execute_async(serialized_data)
outputs = await dag_future outputs = await dag_future
return outputs[0] return self.output_decoder.decode(outputs[0])
async def _driver_execute_model_async( async def _driver_execute_model_async(
self, self,
...@@ -523,8 +555,4 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): ...@@ -523,8 +555,4 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
return await asyncio.gather(*coros) return await asyncio.gather(*coros)
def __del__(self): def __del__(self):
if self.forward_dag is not None: self.shutdown()
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
from typing import List, Optional, Tuple, Union import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
import msgspec
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip, is_hip, is_tpu, is_xpu from vllm.utils import get_ip, is_hip, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)
PG_WAIT_TIMEOUT = 1800
try: try:
import ray import ray
from ray._private.state import available_resources_per_node
from ray.util import placement_group_table
from ray.util.placement_group import PlacementGroup
class RayWorkerWrapper(WorkerWrapperBase): class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be """Ray wrapper for vllm.worker.Worker, allowing Worker to be
...@@ -23,6 +33,10 @@ try: ...@@ -23,6 +33,10 @@ try:
# that thread. # that thread.
self.compiled_dag_cuda_device_set = False self.compiled_dag_cuda_device_set = False
self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=decode_hook)
self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
def get_node_ip(self) -> str: def get_node_ip(self) -> str:
return get_ip() return get_ip()
...@@ -32,16 +46,26 @@ try: ...@@ -32,16 +46,26 @@ try:
return node_id, gpu_ids return node_id, gpu_ids
def execute_model_spmd( def execute_model_spmd(
self, req_or_tuple: Union[ExecuteModelRequest, self, req_or_tuple: Union[bytes,
Tuple[ExecuteModelRequest, Tuple[bytes,
IntermediateTensors]]): Optional[IntermediateTensors]]]
) -> bytes:
"""Execute model in SPMD fashion: used only when SPMD worker and """Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled. compiled DAG are both enabled.
Args: Args:
req_or_tuple: The request to execute the model, or a tuple req_or_tuple: A request or a tuple containing the
containing the request and intermediate tensors. request and intermediate tensors. Intermediate tensors are
None unless if it is provided because it is > 0 pipeline
stage. The request is serialized by msgspec.
""" """
if isinstance(req_or_tuple, bytes):
serialized_req, intermediate_tensors = req_or_tuple, None
else:
serialized_req, intermediate_tensors = req_or_tuple
execute_model_req = self.input_decoder.decode(serialized_req)
# TODO(swang): This is needed right now because Ray aDAG executes # TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current # on a background thread, so we need to reset torch's current
# device. # device.
...@@ -50,16 +74,14 @@ try: ...@@ -50,16 +74,14 @@ try:
torch.cuda.set_device(self.worker.device) torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True self.compiled_dag_cuda_device_set = True
if isinstance(req_or_tuple, tuple):
execute_model_req, intermediate_tensors = req_or_tuple
else:
execute_model_req = req_or_tuple
intermediate_tensors = None
output = self.worker._execute_model_spmd(execute_model_req, output = self.worker._execute_model_spmd(execute_model_req,
intermediate_tensors) intermediate_tensors)
# Pipeline model request and output to the next pipeline stage.
if isinstance(output, IntermediateTensors): if isinstance(output, IntermediateTensors):
return execute_model_req, output output = serialized_req, output
else:
output = self.output_encoder.encode(output)
return output return output
ray_import_err = None ray_import_err = None
...@@ -82,6 +104,106 @@ def assert_ray_available(): ...@@ -82,6 +104,106 @@ def assert_ray_available():
"`pip install ray`.") from ray_import_err "`pip install ray`.") from ray_import_err
def _verify_bundles(placement_group: "PlacementGroup",
parallel_config: ParallelConfig, device_str: str):
"""Verify a given placement group has bundles located in the right place.
There are 2 rules.
- Warn if all tensor parallel workers cannot fit in a single node.
- Fail if driver node is not included in a placement group.
"""
assert ray.is_initialized(), (
"Ray is not initialized although distributed-executor-backend is ray.")
pg_data = placement_group_table(placement_group)
# bundle_idx -> node_id
bundle_to_node_ids = pg_data["bundles_to_node_id"]
# bundle_idx -> bundle (e.g., {"GPU": 1})
bundles = pg_data["bundles"]
# node_id -> List of bundle (e.g., {"GPU": 1})
node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
for bundle_idx, node_id in bundle_to_node_ids.items():
node_id_to_bundle[node_id].append(bundles[bundle_idx])
driver_node_id = ray.get_runtime_context().get_node_id()
if driver_node_id not in node_id_to_bundle:
raise RuntimeError(
f"driver node id {driver_node_id} is not included in a placement "
f"group {placement_group.id}. Node id -> bundles "
f"{node_id_to_bundle}. "
"You don't have enough GPUs available in a current node. Check "
"`ray status` to see if you have available GPUs in a node "
f"{driver_node_id} before starting an vLLM engine.")
for node_id, bundles in node_id_to_bundle.items():
if len(bundles) < parallel_config.tensor_parallel_size:
logger.warning(
"tensor_parallel_size=%d "
"is bigger than a reserved number of %ss (%d "
"%ss) in a node %s. Tensor parallel workers can be "
"spread out to 2+ nodes which can degrade the performance "
"unless you have fast interconnect across nodes, like "
"Infiniband. To resolve this issue, make sure you have more "
"than %d GPUs available at each node.",
parallel_config.tensor_parallel_size, device_str, len(bundles),
device_str, node_id, parallel_config.tensor_parallel_size)
def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
"""Wait until a placement group is ready.
It prints the informative log messages if the placement group is
not created within time.
"""
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
placement_group_specs = current_placement_group.bundle_specs
s = time.time()
pg_ready_ref = current_placement_group.ready()
wait_interval = 10
while time.time() - s < PG_WAIT_TIMEOUT:
ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval)
if len(ready) > 0:
break
# Exponential backoff for warning print.
wait_interval *= 2
logger.info(
"Waiting for creating a placement group of specs for "
"%d seconds. specs=%s. Check "
"`ray status` to see if you have enough resources.",
int(time.time() - s), placement_group_specs)
try:
ray.get(pg_ready_ref, timeout=0)
except ray.exceptions.GetTimeoutError:
raise ValueError(
"Cannot provide a placement group of "
f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
"`ray status` to make sure the cluster has enough resources."
) from None
def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
ray.util.remove_placement_group(current_placement_group)
s = time.time()
wait_interval = 10
while time.time() - s < PG_WAIT_TIMEOUT:
pg = ray.util.get_current_placement_group()
if pg is None:
break
# Exponential backoff for warning print.
wait_interval *= 2
logger.info(
"Waiting for removing a placement group of specs for "
"%d seconds.", int(time.time() - s))
time.sleep(wait_interval)
def initialize_ray_cluster( def initialize_ray_cluster(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
ray_address: Optional[str] = None, ray_address: Optional[str] = None,
...@@ -111,7 +233,7 @@ def initialize_ray_cluster( ...@@ -111,7 +233,7 @@ def initialize_ray_cluster(
# Placement group is already set. # Placement group is already set.
return return
device_str = "GPU" if not is_tpu() else "TPU" device_str = "GPU" if not current_platform.is_tpu() else "TPU"
# Create placement group for worker processes # Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group() current_placement_group = ray.util.get_current_placement_group()
if current_placement_group: if current_placement_group:
...@@ -140,15 +262,32 @@ def initialize_ray_cluster( ...@@ -140,15 +262,32 @@ def initialize_ray_cluster(
f"The number of required {device_str}s exceeds the total " f"The number of required {device_str}s exceeds the total "
f"number of available {device_str}s in the placement group.") f"number of available {device_str}s in the placement group.")
# Create a new placement group # Create a new placement group
placement_group_specs = ([{ placement_group_specs: List[Dict[str, float]] = ([{
device_str: 1 device_str: 1.0
}] * parallel_config.world_size) } for _ in range(parallel_config.world_size)])
# vLLM engine is also a worker to execute model with an accelerator,
# so it requires to have the device in a current node. Check if
# the current node has at least one device.
current_ip = get_ip()
current_node_id = ray.get_runtime_context().get_node_id()
current_node_resource = available_resources_per_node()[current_node_id]
if current_node_resource.get(device_str, 0) < 1:
raise ValueError(
f"Current node has no {device_str} available. "
f"{current_node_resource=}. vLLM engine cannot start without "
f"{device_str}. Make sure you have at least 1 {device_str} "
f"available in a node {current_node_id=} {current_ip=}.")
# This way, at least bundle is required to be created in a current
# node.
placement_group_specs[0][f"node:{current_ip}"] = 0.001
# By default, Ray packs resources as much as possible.
current_placement_group = ray.util.placement_group( current_placement_group = ray.util.placement_group(
placement_group_specs) placement_group_specs, strategy="PACK")
# Wait until PG is ready - this will block until all _wait_until_pg_ready(current_placement_group)
# requested resources are available, and will timeout
# if they cannot be provisioned.
ray.get(current_placement_group.ready(), timeout=1800)
assert current_placement_group is not None
_verify_bundles(current_placement_group, parallel_config, device_str)
# Set the placement group in the parallel config # Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group parallel_config.placement_group = current_placement_group
import asyncio import asyncio
import os from typing import List, Optional
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
Tuple, Union)
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync
ModelConfig, MultiModalConfig, ParallelConfig, from vllm.executor.xpu_executor import XPUExecutor
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.utils import get_vllm_instance_id, make_async
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__) logger = init_logger(__name__)
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayXPUExecutor(DistributedGPUExecutor):
uses_ray: bool = True
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
assert device_config.device_type == "xpu"
assert (not speculative_config
), "Speculative decoding not yet supported for XPU backend"
self.model_config = model_config
self.cache_config = cache_config
self.load_config = load_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
def _init_executor(self) -> None:
pass
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- Tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks", )
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return num_gpu_blocks, num_cpu_blocks
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
return dict(
worker_module_name="vllm.worker.xpu_worker",
worker_class_name="XPUWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
# For single GPU case, we use a ray worker with constrained memory.
num_gpus = self.cache_config.gpu_memory_utilization
else:
# Otherwise, the ray workers are allocated with a full GPU.
num_gpus = 1
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
# Create the workers.
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
else:
# Else, added to the list of workers.
self.workers.append(worker)
if self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")
def _get_env_vars_to_be_updated(self):
# Get the set of GPU IDs used on each node. # Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True) use_dummy_driver=True)
node_workers = defaultdict(list) VLLM_INSTANCE_ID = get_vllm_instance_id()
node_gpus = defaultdict(list)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
# TODO: add env var for xpu
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
def collect_arg_helper_func(**kwargs):
# avoid writing `{"name": value}` manually
return kwargs
init_worker_all_kwargs = []
# Initialize the actual workers inside worker wrapper.
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
local_rank = node_workers[node_id].index(rank)
init_worker_all_kwargs.append(
collect_arg_helper_func(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
is_driver_worker=rank == 0,
))
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device") # Set environment variables for the driver and workers.
self._run_workers( all_args_to_update_environment_variables = [({
"load_model", "VLLM_INSTANCE_ID":
max_concurrent_workers=self.parallel_config. VLLM_INSTANCE_ID,
max_parallel_loading_workers, "VLLM_TRACE_FUNCTION":
) str(envs.VLLM_TRACE_FUNCTION),
}, ) for (_, _) in worker_node_and_gpu_ids]
return all_args_to_update_environment_variables
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache in all workers.
"""
# NOTE: We log here to avoid multiple logs when number of workers is class RayXPUExecutorAsync(RayXPUExecutor, RayGPUExecutorAsync):
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info("# GPU blocks: %d, "
"# CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")
def _run_workers(
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
ways:
- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
if async_run_remote_workers_only:
# Just return futures
return ray_worker_outputs
driver_worker_output = []
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
else:
assert self.driver_dummy_worker is not None
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self, enable_asyncio: bool):
import pkg_resources
from packaging import version
required_version = version.parse("2.32")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with InputNode() as input_data:
forward_dag = MultiOutputNode([
worker.execute_model_compiled_dag_remote.
bind( # type: ignore[attr-defined]
input_data) for worker in self.workers
])
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()
def _check_if_any_actor_is_dead(self):
if not self.workers:
return
dead_actors = []
for actor in self.workers:
actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access
if actor_state["State"] == "DEAD":
dead_actors.append(actor)
if dead_actors:
raise RuntimeError("At least one Worker is dead. "
f"Dead Workers: {dead_actors}. ")
class RayXPUExecutorAsync(RayXPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.driver_exec_method = make_async(self.driver_worker.execute_method) self.driver_exec_method = make_async(self.driver_worker.execute_method)
self.pp_locks: Optional[List[asyncio.Lock]] = None
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_method("execute_model",
execute_model_req)
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers
]
return await asyncio.gather(*coros)
...@@ -52,7 +52,6 @@ class TPUExecutor(ExecutorBase): ...@@ -52,7 +52,6 @@ class TPUExecutor(ExecutorBase):
local_rank=local_rank, local_rank=local_rank,
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
multimodal_config=self.multimodal_config,
is_driver_worker=rank == 0, is_driver_worker=rank == 0,
) )
......
from typing import List, Optional from typing import Callable, List, Optional, Tuple, Type, Union
import torch import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import make_async from vllm.utils import make_async
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -29,9 +29,9 @@ class XPUExecutor(GPUExecutor): ...@@ -29,9 +29,9 @@ class XPUExecutor(GPUExecutor):
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
prompt_adapter_config: Optional[PromptAdapterConfig], prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
observability_config: Optional[ObservabilityConfig],
) -> None: ) -> None:
assert device_config.device_type == "xpu" assert device_config.device_type == "xpu"
assert (not speculative_config assert (not speculative_config
...@@ -46,35 +46,27 @@ class XPUExecutor(GPUExecutor): ...@@ -46,35 +46,27 @@ class XPUExecutor(GPUExecutor):
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config self.prompt_adapter_config = prompt_adapter_config
self.speculative_config = None self.speculative_config = None
self.observability_config = observability_config
# Instantiate the worker and load the model to GPU. # Instantiate the worker and load the model to GPU.
self._init_executor() self._init_executor()
def _create_worker(self, def _get_worker_module_and_class(
local_rank: int = 0, self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
rank: int = 0, worker_class_fn = None
distributed_init_method: Optional[str] = None): if self.speculative_config is not None:
if self.speculative_config is None:
worker_module_name = "vllm.worker.xpu_worker"
worker_class_name = "XPUWorker"
else:
raise NotImplementedError( raise NotImplementedError(
"XPU does not support speculative decoding") "XPU does not support speculative decoding")
else:
wrapper = WorkerWrapperBase( worker_module_name = "vllm.worker.xpu_worker"
worker_module_name=worker_module_name, worker_class_name = "XPUWorker"
worker_class_name=worker_class_name, return (worker_module_name, worker_class_name, worker_class_fn)
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
def execute_model( def execute_model(
self, self, execute_model_req: ExecuteModelRequest
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
output = self.driver_worker.execute_model(execute_model_req) output = self.driver_worker.execute_model(execute_model_req)
return output return output
......
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
TextPrompt, TokensPrompt, parse_and_batch_prompt) LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
...@@ -12,7 +14,17 @@ See also: ...@@ -12,7 +14,17 @@ See also:
""" """
__all__ = [ __all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", "TextPrompt",
"TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY", "TokensPrompt",
"InputContext", "InputRegistry" "PromptInputs",
"SingletonPromptInputs",
"ExplicitEncoderDecoderPrompt",
"LLMInputs",
"EncoderDecoderLLMInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"INPUT_REGISTRY",
"InputContext",
"InputRegistry",
] ]
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
TypedDict, Union, cast, overload) Union)
from typing_extensions import NotRequired from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0], str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False)
for elem in cast(List[str], prompt)
]
if isinstance(prompt[0], int):
# case 3: array of tokens
elem = cast(List[int], prompt)
return [ParsedTokens(content=elem, is_tokens=True)]
if isinstance(prompt[0], list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0][0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in cast(List[List[int]], prompt)
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
class TextPrompt(TypedDict): class TextPrompt(TypedDict):
"""Schema for a text prompt.""" """Schema for a text prompt."""
...@@ -92,12 +33,71 @@ class TokensPrompt(TypedDict): ...@@ -92,12 +33,71 @@ class TokensPrompt(TypedDict):
""" """
PromptInputs = Union[str, TextPrompt, TokensPrompt] SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
""" """
The inputs to the LLM, which can take one of the following forms: Set of possible schemas for a single LLM input:
- A text prompt (:class:`str` or :class:`TextPrompt`) - A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`) - A tokenized prompt (:class:`TokensPrompt`)
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPromptInputs` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
_T1_co = TypeVar("_T1_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
_T2_co = TypeVar("_T2_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
"""Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
The encoder and decoder prompts, respectively,
may formatted according to any of the
:class:`SingletonPromptInputs` schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
:class:`SingletonPromptInputs` instances.
"""
encoder_prompt: _T1_co
decoder_prompt: Optional[_T2_co]
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- A single data structure containing both an encoder and a decoder prompt
(:class:`ExplicitEncoderDecoderPrompt`)
""" """
...@@ -105,6 +105,8 @@ class LLMInputs(TypedDict): ...@@ -105,6 +105,8 @@ class LLMInputs(TypedDict):
""" """
The inputs in :class:`~vllm.LLMEngine` before they are The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor. passed to the model executor.
This specifies the data required for decoder-only models.
""" """
prompt_token_ids: List[int] prompt_token_ids: List[int]
"""The token IDs of the prompt.""" """The token IDs of the prompt."""
...@@ -119,3 +121,58 @@ class LLMInputs(TypedDict): ...@@ -119,3 +121,58 @@ class LLMInputs(TypedDict):
Optional multi-modal data to pass to the model, Optional multi-modal data to pass to the model,
if the model supports it. if the model supports it.
""" """
class EncoderDecoderLLMInputs(LLMInputs):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder_prompt_token_ids: List[int]
"""The token IDs of the encoder prompt."""
encoder_prompt: NotRequired[Optional[str]]
"""
The original encoder prompt text corresponding to the token IDs, if
available.
"""
_T1 = TypeVar("_T1",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
_T2 = TypeVar("_T2",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
def build_explicit_enc_dec_prompt(
encoder_prompt: _T1,
decoder_prompt: Optional[_T2],
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt)
def zip_enc_dec_prompts(
enc_prompts: Iterable[_T1],
dec_prompts: Iterable[Optional[_T2]],
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
"""
Zip encoder and decoder prompts together into a list of
:class:`ExplicitEncoderDecoderPrompt` instances.
"""
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
]
def to_enc_dec_tuple_list(
enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> List[Tuple[_T1, Optional[_T2]]]:
return [(enc_dec_prompt["encoder_prompt"],
enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts]
from typing import List, Literal, Sequence, TypedDict, Union, overload
from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs)
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if is_list_of(prompt, str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False) for elem in prompt
]
if is_list_of(prompt, int):
# case 3: array of tokens
return [ParsedTokens(content=prompt, is_tokens=True)]
if is_list_of(prompt, list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if is_list_of(prompt[0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in prompt
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
def is_explicit_encoder_decoder_prompt(
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(inputs, dict) and "encoder_prompt" in inputs
def is_valid_encoder_decoder_llm_inputs(
inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
) -> TypeIs[EncoderDecoderLLMInputs]:
return "encoder_prompt_token_ids" in inputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment