Commit f48954a4 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.5.0

parents 1dba29d3 8f89d720
...@@ -29,23 +29,32 @@ class AsyncEngineDeadError(RuntimeError): ...@@ -29,23 +29,32 @@ class AsyncEngineDeadError(RuntimeError):
pass pass
def _raise_exception_on_finish( def _log_task_completion(task: asyncio.Task,
task: asyncio.Task, error_callback: Callable[[Exception], error_callback: Callable[[Exception], None]) -> None:
None]) -> None: """This function is only intended for the `engine.run_engine_loop()` task.
msg = ("Task finished unexpectedly. This should never happen! "
"Please open an issue on Github.") In particular, that task runs a `while True` loop that can only exit if
there is an exception.
"""
exception = None exception = None
try: try:
task.result() return_value = task.result()
# NOTE: This will be thrown if task exits normally (which it should not) raise AssertionError(
raise AsyncEngineDeadError(msg) f"The engine background task should never finish without an "
f"exception. {return_value}")
except asyncio.exceptions.CancelledError:
# We assume that if the task is cancelled, we are gracefully shutting
# down. This should only happen on program exit.
logger.info("Engine is gracefully shutting down.")
except Exception as e: except Exception as e:
exception = e exception = e
logger.error("Engine background task failed", exc_info=e) logger.error("Engine background task failed", exc_info=e)
error_callback(exception) error_callback(exception)
raise AsyncEngineDeadError( raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from e "Task finished unexpectedly. This should never happen! "
"Please open an issue on Github. See stack trace above for the"
"actual cause.") from e
class AsyncStream: class AsyncStream:
...@@ -438,8 +447,7 @@ class AsyncLLMEngine: ...@@ -438,8 +447,7 @@ class AsyncLLMEngine:
self._background_loop_unshielded = asyncio.get_event_loop( self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop()) ).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback( self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish, partial(_log_task_completion, error_callback=self._error_callback))
error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded) self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args, def _init_engine(self, *args,
......
...@@ -162,7 +162,7 @@ class LLMEngine: ...@@ -162,7 +162,7 @@ class LLMEngine:
"Initializing an LLM engine (v%s) with config: " "Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, " "model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"rope_scaling=%r, tokenizer_revision=%s, " "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, " "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, " "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, " "disable_custom_all_reduce=%s, quantization=%s, "
...@@ -177,6 +177,7 @@ class LLMEngine: ...@@ -177,6 +177,7 @@ class LLMEngine:
model_config.tokenizer_mode, model_config.tokenizer_mode,
model_config.revision, model_config.revision,
model_config.rope_scaling, model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision, model_config.tokenizer_revision,
model_config.trust_remote_code, model_config.trust_remote_code,
model_config.dtype, model_config.dtype,
......
...@@ -78,7 +78,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -78,7 +78,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Since there's only one sequence per sequence group, we can take the # Since there's only one sequence per sequence group, we can take the
# first sample. # first sample.
samples = [outputs[step].samples[0] for step in range(len(outputs))] samples = [output.samples[0] for output in outputs]
# -1 means the output token is not valid (eg. due to spec decode # -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens). # rejecting tokens).
......
...@@ -60,10 +60,10 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -60,10 +60,10 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
assert len(outputs) == 1, ("Single step should only has 1 output.") assert len(outputs) == 1, ("Single step should only has 1 output.")
output = outputs[0] output = outputs[0]
prompt_logprobs = output.prompt_logprobs prompt_logprobs = output.prompt_logprobs
if (prompt_logprobs is not None if prompt_logprobs is not None:
and seq_group.sampling_params.detokenize and self.detokenizer): if seq_group.sampling_params.detokenize and self.detokenizer:
self.detokenizer.decode_prompt_logprobs_inplace( self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs) seq_group, prompt_logprobs)
if not seq_group.prompt_logprobs: if not seq_group.prompt_logprobs:
# The first prompt token's logprob is None because it doesn't # The first prompt token's logprob is None because it doesn't
# have tokens that are precedent. # have tokens that are precedent.
......
...@@ -14,7 +14,7 @@ from vllm.lora.request import LoRARequest ...@@ -14,7 +14,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs from vllm.utils import Counter, deprecate_kwargs
...@@ -153,7 +153,14 @@ class LLM: ...@@ -153,7 +153,14 @@ class LLM:
self, self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None: ) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer # While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
self.llm_engine.tokenizer.tokenizer = tokenizer
else:
self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
tokenizer)
@overload # LEGACY: single (prompt + optional token ids) @overload # LEGACY: single (prompt + optional token ids)
def generate( def generate(
...@@ -163,8 +170,7 @@ class LLM: ...@@ -163,8 +170,7 @@ class LLM:
List[SamplingParams]]] = None, List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
... ...
...@@ -176,8 +182,7 @@ class LLM: ...@@ -176,8 +182,7 @@ class LLM:
List[SamplingParams]]] = None, List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
... ...
...@@ -190,8 +195,7 @@ class LLM: ...@@ -190,8 +195,7 @@ class LLM:
*, *,
prompt_token_ids: List[int], prompt_token_ids: List[int],
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
... ...
...@@ -204,8 +208,7 @@ class LLM: ...@@ -204,8 +208,7 @@ class LLM:
*, *,
prompt_token_ids: List[List[int]], prompt_token_ids: List[List[int]],
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
... ...
...@@ -216,8 +219,7 @@ class LLM: ...@@ -216,8 +219,7 @@ class LLM:
sampling_params: None, sampling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]], prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
... ...
...@@ -230,13 +232,12 @@ class LLM: ...@@ -230,13 +232,12 @@ class LLM:
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None, Sequence[SamplingParams]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
... ...
@deprecate_kwargs("prompts", @deprecate_kwargs("prompts",
"prompt_token_ids", "prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY, is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter " additional_message="Please use the 'inputs' parameter "
"instead.") "instead.")
...@@ -248,8 +249,7 @@ class LLM: ...@@ -248,8 +249,7 @@ class LLM:
Sequence[SamplingParams]]] = None, Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -276,11 +276,15 @@ class LLM: ...@@ -276,11 +276,15 @@ class LLM:
considered legacy and may be deprecated in the future. You should considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter. instead pass them via the ``inputs`` parameter.
""" """
if prompt_token_ids is not None or multi_modal_data is not None: if self.llm_engine.model_config.embedding_mode:
raise ValueError(
"LLM.generate() is only supported for generation models "
"(XForCausalLM).")
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs( inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts), prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
) )
else: else:
inputs = cast( inputs = cast(
...@@ -308,8 +312,7 @@ class LLM: ...@@ -308,8 +312,7 @@ class LLM:
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
... ...
...@@ -321,8 +324,7 @@ class LLM: ...@@ -321,8 +324,7 @@ class LLM:
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
... ...
...@@ -335,8 +337,7 @@ class LLM: ...@@ -335,8 +337,7 @@ class LLM:
*, *,
prompt_token_ids: List[int], prompt_token_ids: List[int],
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
... ...
...@@ -349,8 +350,7 @@ class LLM: ...@@ -349,8 +350,7 @@ class LLM:
*, *,
prompt_token_ids: List[List[int]], prompt_token_ids: List[List[int]],
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
... ...
...@@ -361,8 +361,7 @@ class LLM: ...@@ -361,8 +361,7 @@ class LLM:
pooling_params: None, pooling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]], prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
... ...
...@@ -375,13 +374,12 @@ class LLM: ...@@ -375,13 +374,12 @@ class LLM:
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
... ...
@deprecate_kwargs("prompts", @deprecate_kwargs("prompts",
"prompt_token_ids", "prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY, is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter " additional_message="Please use the 'inputs' parameter "
"instead.") "instead.")
...@@ -393,8 +391,7 @@ class LLM: ...@@ -393,8 +391,7 @@ class LLM:
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -420,11 +417,15 @@ class LLM: ...@@ -420,11 +417,15 @@ class LLM:
considered legacy and may be deprecated in the future. You should considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter. instead pass them via the ``inputs`` parameter.
""" """
if prompt_token_ids is not None or multi_modal_data is not None: if not self.llm_engine.model_config.embedding_mode:
raise ValueError(
"LLM.encode() is only supported for embedding models (XModel)."
)
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs( inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts), prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
) )
else: else:
inputs = cast( inputs = cast(
...@@ -449,7 +450,6 @@ class LLM: ...@@ -449,7 +450,6 @@ class LLM:
self, self,
prompts: Optional[Union[str, List[str]]], prompts: Optional[Union[str, List[str]]],
prompt_token_ids: Optional[Union[List[int], List[List[int]]]], prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
multi_modal_data: Optional[MultiModalData],
): ):
# skip_tokenizer_init is now checked in engine # skip_tokenizer_init is now checked in engine
...@@ -489,9 +489,6 @@ class LLM: ...@@ -489,9 +489,6 @@ class LLM:
else: else:
raise AssertionError raise AssertionError
if multi_modal_data is not None:
item["multi_modal_data"] = multi_modal_data
inputs.append(item) inputs.append(item)
return inputs return inputs
...@@ -501,7 +498,7 @@ class LLM: ...@@ -501,7 +498,7 @@ class LLM:
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]], Sequence[PoolingParams]],
lora_request: Optional[LoRARequest], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
) -> None: ) -> None:
if isinstance(inputs, (str, dict)): if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list. # Convert a single prompt to a list.
...@@ -512,20 +509,25 @@ class LLM: ...@@ -512,20 +509,25 @@ class LLM:
if isinstance(params, list) and len(params) != num_requests: if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params " raise ValueError("The lengths of prompts and params "
"must be the same.") "must be the same.")
if isinstance(lora_request,
list) and len(lora_request) != num_requests:
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")
# Add requests to the engine. # Add requests to the engine.
for i, request_inputs in enumerate(inputs): for i, request_inputs in enumerate(inputs):
self._add_request( self._add_request(
request_inputs, request_inputs,
params[i] if isinstance(params, Sequence) else params, params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request, lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
) )
def _add_request( def _add_request(
self, self,
inputs: PromptInputs, inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id, self.llm_engine.add_request(request_id,
......
...@@ -36,7 +36,7 @@ openai_serving_chat: OpenAIServingChat ...@@ -36,7 +36,7 @@ openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding openai_serving_embedding: OpenAIServingEmbedding
logger = init_logger(__name__) logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set() _running_tasks: Set[asyncio.Task] = set()
...@@ -183,6 +183,16 @@ if __name__ == "__main__": ...@@ -183,6 +183,16 @@ if __name__ == "__main__":
served_model_names = [args.model] served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
# Enforce pixel values as image input type for vision language models
# when serving with API server
if engine_args.image_input_type is not None and \
engine_args.image_input_type.upper() != "PIXEL_VALUES":
raise ValueError(
f"Invalid image_input_type: {engine_args.image_input_type}. "
"Only --image-input-type 'pixel_values' is supported for serving "
"vision language models with the vLLM API server.")
engine = AsyncLLMEngine.from_engine_args( engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER) engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
......
...@@ -82,6 +82,7 @@ class ModelCard(OpenAIBaseModel): ...@@ -82,6 +82,7 @@ class ModelCard(OpenAIBaseModel):
owned_by: str = "vllm" owned_by: str = "vllm"
root: Optional[str] = None root: Optional[str] = None
parent: Optional[str] = None parent: Optional[str] = None
max_model_len: Optional[int] = None
permission: List[ModelPermission] = Field(default_factory=list) permission: List[ModelPermission] = Field(default_factory=list)
...@@ -101,6 +102,30 @@ class ResponseFormat(OpenAIBaseModel): ...@@ -101,6 +102,30 @@ class ResponseFormat(OpenAIBaseModel):
type: Literal["text", "json_object"] type: Literal["text", "json_object"]
class StreamOptions(OpenAIBaseModel):
include_usage: Optional[bool]
class FunctionDefinition(OpenAIBaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None
class ChatCompletionToolsParam(OpenAIBaseModel):
type: Literal["function"] = "function"
function: FunctionDefinition
class ChatCompletionNamedFunction(OpenAIBaseModel):
name: str
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
function: ChatCompletionNamedFunction
type: Literal["function"] = "function"
class ChatCompletionRequest(OpenAIBaseModel): class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create # https://platform.openai.com/docs/api-reference/chat/create
...@@ -119,8 +144,12 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -119,8 +144,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
le=torch.iinfo(torch.long).max) le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = 0.7 temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0 top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
ChatCompletionNamedToolChoiceParam]] = "none"
user: Optional[str] = None user: Optional[str] = None
# doc: begin-chat-completion-sampling-params # doc: begin-chat-completion-sampling-params
...@@ -152,6 +181,15 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -152,6 +181,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"This is a parameter used by chat template in tokenizer config of the " "This is a parameter used by chat template in tokenizer config of the "
"model."), "model."),
) )
add_special_tokens: Optional[bool] = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to False (as is the "
"default)."),
)
include_stop_str_in_output: Optional[bool] = Field( include_stop_str_in_output: Optional[bool] = Field(
default=False, default=False,
description=( description=(
...@@ -236,6 +274,15 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -236,6 +274,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
logits_processors=logits_processors, logits_processors=logits_processors,
) )
@model_validator(mode='before')
@classmethod
def validate_stream_options(cls, values):
if (values.get('stream_options') is not None
and not values.get('stream')):
raise ValueError(
"stream_options can only be set if stream is true")
return values
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_guided_decoding_count(cls, data): def check_guided_decoding_count(cls, data):
...@@ -244,10 +291,27 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -244,10 +291,27 @@ class ChatCompletionRequest(OpenAIBaseModel):
"guided_regex" in data and data["guided_regex"] is not None, "guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None "guided_choice" in data and data["guided_choice"] is not None
]) ])
# you can only use one kind of guided decoding
if guide_count > 1: if guide_count > 1:
raise ValueError( raise ValueError(
"You can only use one kind of guided decoding " "You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').") "('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and "tool_choice" in data and data[
"tool_choice"] != "none":
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data
@model_validator(mode="before")
@classmethod
def check_tool_choice(cls, data):
if "tool_choice" in data and data["tool_choice"] != "none":
if not isinstance(data["tool_choice"], dict):
raise ValueError("Currently only named tools are supported.")
if "tools" not in data or data["tools"] is None:
raise ValueError(
"When using `tool_choice`, `tools` must be set.")
return data return data
@model_validator(mode="before") @model_validator(mode="before")
...@@ -258,9 +322,9 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -258,9 +322,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
raise ValueError( raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true." "when using `top_logprobs`, `logprobs` must be set to true."
) )
elif not 0 <= data["top_logprobs"] <= 20: elif data["top_logprobs"] < 0:
raise ValueError( raise ValueError(
"`top_logprobs` must be a value in the interval [0, 20].") "`top_logprobs` must be a value a positive value.")
return data return data
...@@ -282,6 +346,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -282,6 +346,7 @@ class CompletionRequest(OpenAIBaseModel):
le=torch.iinfo(torch.long).max) le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None suffix: Optional[str] = None
temperature: Optional[float] = 1.0 temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0 top_p: Optional[float] = 1.0
...@@ -413,9 +478,16 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -413,9 +478,16 @@ class CompletionRequest(OpenAIBaseModel):
@classmethod @classmethod
def check_logprobs(cls, data): def check_logprobs(cls, data):
if "logprobs" in data and data[ if "logprobs" in data and data[
"logprobs"] is not None and not 0 <= data["logprobs"] <= 5: "logprobs"] is not None and not data["logprobs"] >= 0:
raise ValueError(("if passed, `logprobs` must be a value", raise ValueError("if passed, `logprobs` must be a positive value.")
" in the interval [0, 5].")) return data
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when stream is True.")
return data return data
...@@ -441,7 +513,8 @@ class CompletionLogProbs(OpenAIBaseModel): ...@@ -441,7 +513,8 @@ class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list) text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(OpenAIBaseModel): class CompletionResponseChoice(OpenAIBaseModel):
...@@ -505,9 +578,21 @@ class EmbeddingResponse(BaseModel): ...@@ -505,9 +578,21 @@ class EmbeddingResponse(BaseModel):
usage: UsageInfo usage: UsageInfo
class FunctionCall(OpenAIBaseModel):
name: str
arguments: str
class ToolCall(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
function: FunctionCall
class ChatMessage(OpenAIBaseModel): class ChatMessage(OpenAIBaseModel):
role: str role: str
content: str content: str
tool_calls: List[ToolCall] = Field(default_factory=list)
class ChatCompletionLogProb(OpenAIBaseModel): class ChatCompletionLogProb(OpenAIBaseModel):
...@@ -528,13 +613,13 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): ...@@ -528,13 +613,13 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int index: int
message: ChatMessage message: ChatMessage
logprobs: Optional[ChatCompletionLogProbs] = None logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None stop_reason: Optional[Union[int, str]] = None
class ChatCompletionResponse(OpenAIBaseModel): class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion" object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[ChatCompletionResponseChoice] choices: List[ChatCompletionResponseChoice]
...@@ -544,19 +629,20 @@ class ChatCompletionResponse(OpenAIBaseModel): ...@@ -544,19 +629,20 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel): class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None role: Optional[str] = None
content: Optional[str] = None content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
class ChatCompletionResponseStreamChoice(OpenAIBaseModel): class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int index: int
delta: DeltaMessage delta: DeltaMessage
logprobs: Optional[ChatCompletionLogProbs] = None logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None stop_reason: Optional[Union[int, str]] = None
class ChatCompletionStreamResponse(OpenAIBaseModel): class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk" object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[ChatCompletionResponseStreamChoice] choices: List[ChatCompletionResponseStreamChoice]
......
import codecs import codecs
import time import time
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List, from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
Optional) List, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import TypedDict, Union, cast, final from typing import TypedDict, Union, cast, final
from fastapi import Request from fastapi import Request
from openai.types.chat import ChatCompletionContentPartTextParam from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam)
from vllm.config import ModelConfig from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionContentPartParam, ChatCompletionLogProb, ChatCompletionContentPartParam, ChatCompletionLogProb,
ChatCompletionLogProbs, ChatCompletionLogProbsContent, ChatCompletionLogProbs, ChatCompletionLogProbsContent,
ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo) FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing) OpenAIServing)
from vllm.inputs import PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor) get_guided_decoding_logits_processor)
from vllm.multimodal.image import ImagePixelData
from vllm.multimodal.utils import (async_get_and_parse_image,
get_full_image_text_prompt)
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -39,6 +45,8 @@ class ConversationMessage(TypedDict): ...@@ -39,6 +45,8 @@ class ConversationMessage(TypedDict):
@dataclass(frozen=True) @dataclass(frozen=True)
class ChatMessageParseResult: class ChatMessageParseResult:
messages: List[ConversationMessage] messages: List[ConversationMessage]
image_futures: List[Awaitable[ImagePixelData]] = field(
default_factory=list)
class OpenAIServingChat(OpenAIServing): class OpenAIServingChat(OpenAIServing):
...@@ -93,19 +101,76 @@ class OpenAIServingChat(OpenAIServing): ...@@ -93,19 +101,76 @@ class OpenAIServingChat(OpenAIServing):
parts: Iterable[ChatCompletionContentPartParam], parts: Iterable[ChatCompletionContentPartParam],
) -> ChatMessageParseResult: ) -> ChatMessageParseResult:
texts: List[str] = [] texts: List[str] = []
image_futures: List[Awaitable[ImagePixelData]] = []
for _, part in enumerate(parts): vlm_config: Optional[VisionLanguageConfig] = getattr(
self.engine.engine, "vision_language_config", None)
model_config = getattr(self.engine.engine, "model_config", None)
for part in parts:
part_type = part["type"] part_type = part["type"]
if part_type == "text": if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"] text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text) texts.append(text)
elif part_type == "image_url":
if vlm_config is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model is not multimodal.")
elif len(image_futures) == 0:
assert self.tokenizer is not None
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
image_future = async_get_and_parse_image(image_url["url"])
image_futures.append(image_future)
else:
raise NotImplementedError(
"Multiple 'image_url' input is currently not supported."
)
else: else:
raise NotImplementedError(f"Unknown part type: {part_type}") raise NotImplementedError(f"Unknown part type: {part_type}")
messages = [ConversationMessage(role=role, content="\n".join(texts))] text_prompt = "\n".join(texts)
if vlm_config is not None and len(image_futures):
(image_token_prompt,
image_token_str) = vlm_config.get_image_token_text(self.tokenizer)
# NOTE: If image token string (e.g, <image>) is already present
# in the text prompt, we assume it follows the same format required
# by the engine.
if image_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Skipping prompt formatting.")
messages = [
ConversationMessage(role=role, content=text_prompt)
]
else:
full_prompt = get_full_image_text_prompt(
image_prompt=image_token_prompt,
text_prompt=text_prompt,
config=model_config)
messages = [
ConversationMessage(role=role, content=full_prompt)
]
else:
messages = [ConversationMessage(role=role, content=text_prompt)]
return ChatMessageParseResult(messages=messages) return ChatMessageParseResult(messages=messages,
image_futures=image_futures)
def _parse_chat_message_content( def _parse_chat_message_content(
self, self,
...@@ -115,10 +180,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -115,10 +180,10 @@ class OpenAIServingChat(OpenAIServing):
content = message.get("content") content = message.get("content")
if content is None: if content is None:
return ChatMessageParseResult(messages=[]) return ChatMessageParseResult(messages=[], image_futures=[])
if isinstance(content, str): if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)] messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages) return ChatMessageParseResult(messages=messages, image_futures=[])
return self._parse_chat_message_content_parts(role, content) return self._parse_chat_message_content_parts(role, content)
...@@ -143,11 +208,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -143,11 +208,13 @@ class OpenAIServingChat(OpenAIServing):
try: try:
conversation: List[ConversationMessage] = [] conversation: List[ConversationMessage] = []
image_futures: List[Awaitable[ImagePixelData]] = []
for msg in request.messages: for msg in request.messages:
parsed_msg = self._parse_chat_message_content(msg) chat_parsed_result = self._parse_chat_message_content(msg)
conversation.extend(parsed_msg.messages) conversation.extend(chat_parsed_result.messages)
image_futures.extend(chat_parsed_result.image_futures)
prompt = self.tokenizer.apply_chat_template( prompt = self.tokenizer.apply_chat_template(
conversation=conversation, conversation=conversation,
...@@ -158,11 +225,24 @@ class OpenAIServingChat(OpenAIServing): ...@@ -158,11 +225,24 @@ class OpenAIServingChat(OpenAIServing):
logger.error("Error in applying chat template from request: %s", e) logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e)) return self.create_error_response(str(e))
# Fetch image data
image_data: Optional[ImagePixelData] = None
try:
if len(image_futures):
# since we support only single image currently
assert len(image_futures) == 1
image_data = await image_futures[0]
except Exception as e:
logger.error("Error in loading image data: %s", e)
return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
try: try:
# Tokenize/detokenize depending on prompt format (string/token list) # Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = self._validate_prompt_and_tokenize( prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
request, prompt=prompt, add_special_tokens=False) request,
prompt=prompt,
add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
decoding_config = await self.engine.get_decoding_config() decoding_config = await self.engine.get_decoding_config()
...@@ -180,11 +260,15 @@ class OpenAIServingChat(OpenAIServing): ...@@ -180,11 +260,15 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e: except ValueError as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
inputs: PromptInputs = {
"prompt": prompt_text,
"prompt_token_ids": prompt_ids,
}
if image_data is not None:
inputs["multi_modal_data"] = image_data
result_generator = self.engine.generate( result_generator = self.engine.generate(
{ inputs,
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params, sampling_params,
request_id, request_id,
lora_request, lora_request,
...@@ -244,6 +328,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -244,6 +328,9 @@ class OpenAIServingChat(OpenAIServing):
created=created_time, created=created_time,
choices=[choice_data], choices=[choice_data],
model=model_name) model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
...@@ -271,6 +358,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -271,6 +358,9 @@ class OpenAIServingChat(OpenAIServing):
choices=[choice_data], choices=[choice_data],
logprobs=None, logprobs=None,
model=model_name) model=model_name)
if (request.stream_options and
request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json( data = chunk.model_dump_json(
exclude_unset=True) exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
...@@ -283,13 +373,15 @@ class OpenAIServingChat(OpenAIServing): ...@@ -283,13 +373,15 @@ class OpenAIServingChat(OpenAIServing):
continue continue
delta_token_ids = output.token_ids[previous_num_tokens[i]:] delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[ out_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs: if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_chat_logprobs( logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids, token_ids=delta_token_ids,
top_logprobs=top_logprobs, top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs, num_output_top_logprobs=request.top_logprobs,
) )
else: else:
...@@ -298,11 +390,24 @@ class OpenAIServingChat(OpenAIServing): ...@@ -298,11 +390,24 @@ class OpenAIServingChat(OpenAIServing):
delta_text = output.text[len(previous_texts[i]):] delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids) previous_num_tokens[i] = len(output.token_ids)
if request.tool_choice and type(
request.tool_choice
) is ChatCompletionNamedToolChoiceParam:
delta_message = DeltaMessage(tool_calls=[
ToolCall(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text))
])
else:
delta_message = DeltaMessage(content=delta_text)
if output.finish_reason is None: if output.finish_reason is None:
# Send token-by-token response for each request.n # Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=i, index=i,
delta=DeltaMessage(content=delta_text), delta=delta_message,
logprobs=logprobs, logprobs=logprobs,
finish_reason=None) finish_reason=None)
chunk = ChatCompletionStreamResponse( chunk = ChatCompletionStreamResponse(
...@@ -311,20 +416,17 @@ class OpenAIServingChat(OpenAIServing): ...@@ -311,20 +416,17 @@ class OpenAIServingChat(OpenAIServing):
created=created_time, created=created_time,
choices=[choice_data], choices=[choice_data],
model=model_name) model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
else: else:
# Send the finish response for each request.n only once # Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids) prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens +
previous_num_tokens[i],
)
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=i, index=i,
delta=DeltaMessage(content=delta_text), delta=delta_message,
logprobs=logprobs, logprobs=logprobs,
finish_reason=output.finish_reason, finish_reason=output.finish_reason,
stop_reason=output.stop_reason) stop_reason=output.stop_reason)
...@@ -334,12 +436,32 @@ class OpenAIServingChat(OpenAIServing): ...@@ -334,12 +436,32 @@ class OpenAIServingChat(OpenAIServing):
created=created_time, created=created_time,
choices=[choice_data], choices=[choice_data],
model=model_name) model=model_name)
if final_usage is not None: if (request.stream_options
chunk.usage = final_usage and request.stream_options.include_usage):
data = chunk.model_dump_json(exclude_unset=True, chunk.usage = None
exclude_none=True) data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
finish_reason_sent[i] = True finish_reason_sent[i] = True
if (request.stream_options
and request.stream_options.include_usage):
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
)
final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e)) data = self.create_streaming_error_response(str(e))
...@@ -370,20 +492,34 @@ class OpenAIServingChat(OpenAIServing): ...@@ -370,20 +492,34 @@ class OpenAIServingChat(OpenAIServing):
role = self.get_chat_request_role(request) role = self.get_chat_request_role(request)
for output in final_res.outputs: for output in final_res.outputs:
token_ids = output.token_ids token_ids = output.token_ids
top_logprobs = output.logprobs out_logprobs = output.logprobs
if request.logprobs: if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_chat_logprobs( logprobs = self._create_chat_logprobs(
token_ids=token_ids, token_ids=token_ids,
top_logprobs=top_logprobs, top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs, num_output_top_logprobs=request.top_logprobs,
) )
else: else:
logprobs = None logprobs = None
if request.tool_choice and type(
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=output.text))
])
elif not request.tool_choice or request.tool_choice == "none":
message = ChatMessage(role=role, content=output.text)
choice_data = ChatCompletionResponseChoice( choice_data = ChatCompletionResponseChoice(
index=output.index, index=output.index,
message=ChatMessage(role=role, content=output.text), message=message,
logprobs=logprobs, logprobs=logprobs,
finish_reason=output.finish_reason, finish_reason=output.finish_reason,
stop_reason=output.stop_reason) stop_reason=output.stop_reason)
......
...@@ -8,6 +8,7 @@ from fastapi import Request ...@@ -8,6 +8,7 @@ from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs, from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionRequest, CompletionRequest,
...@@ -16,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, ...@@ -16,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseStreamChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, CompletionStreamResponse,
UsageInfo) UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing) OpenAIServing)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -221,7 +221,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -221,7 +221,7 @@ class OpenAIServingCompletion(OpenAIServing):
# only return the prompt # only return the prompt
delta_text = res.prompt delta_text = res.prompt
delta_token_ids = res.prompt_token_ids delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs out_logprobs = res.prompt_logprobs
has_echoed[i] = True has_echoed[i] = True
elif (request.echo and request.max_tokens > 0 elif (request.echo and request.max_tokens > 0
and not has_echoed[i]): and not has_echoed[i]):
...@@ -229,7 +229,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -229,7 +229,7 @@ class OpenAIServingCompletion(OpenAIServing):
delta_text = res.prompt + output.text delta_text = res.prompt + output.text
delta_token_ids = (res.prompt_token_ids + delta_token_ids = (res.prompt_token_ids +
output.token_ids) output.token_ids)
top_logprobs = res.prompt_logprobs + (output.logprobs out_logprobs = res.prompt_logprobs + (output.logprobs
or []) or [])
has_echoed[i] = True has_echoed[i] = True
else: else:
...@@ -237,13 +237,15 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -237,13 +237,15 @@ class OpenAIServingCompletion(OpenAIServing):
delta_text = output.text[len(previous_texts[i]):] delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[ delta_token_ids = output.token_ids[
previous_num_tokens[i]:] previous_num_tokens[i]:]
top_logprobs = output.logprobs[previous_num_tokens[ out_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None i]:] if output.logprobs else None
if request.logprobs is not None: if request.logprobs is not None:
assert out_logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_completion_logprobs( logprobs = self._create_completion_logprobs(
token_ids=delta_token_ids, token_ids=delta_token_ids,
top_logprobs=top_logprobs, top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs, num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]), initial_text_offset=len(previous_texts[i]),
) )
...@@ -264,7 +266,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -264,7 +266,8 @@ class OpenAIServingCompletion(OpenAIServing):
) )
else: else:
final_usage = None final_usage = None
response_json = CompletionStreamResponse(
chunk = CompletionStreamResponse(
id=request_id, id=request_id,
created=created_time, created=created_time,
model=model_name, model=model_name,
...@@ -276,10 +279,27 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -276,10 +279,27 @@ class OpenAIServingCompletion(OpenAIServing):
finish_reason=finish_reason, finish_reason=finish_reason,
stop_reason=stop_reason, stop_reason=stop_reason,
) )
], ])
usage=final_usage, if (request.stream_options
).model_dump_json(exclude_unset=True) and request.stream_options.include_usage):
chunk.usage = None
response_json = chunk.model_dump_json(exclude_unset=True)
yield f"data: {response_json}\n\n" yield f"data: {response_json}\n\n"
if (request.stream_options
and request.stream_options.include_usage):
final_usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=final_usage,
)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e)) data = self.create_streaming_error_response(str(e))
...@@ -307,25 +327,23 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -307,25 +327,23 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids token_ids = prompt_token_ids
top_logprobs = prompt_logprobs out_logprobs = prompt_logprobs
output_text = prompt_text output_text = prompt_text
elif request.echo and request.max_tokens > 0: elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids token_ids = prompt_token_ids + output.token_ids
top_logprobs = (prompt_logprobs + output.logprobs out_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs else None) if request.logprobs is not None else None)
output_text = prompt_text + output.text output_text = prompt_text + output.text
else: else:
token_ids = output.token_ids token_ids = output.token_ids
top_logprobs = output.logprobs out_logprobs = output.logprobs
output_text = output.text output_text = output.text
if request.logprobs is not None: if request.logprobs is not None:
assert top_logprobs is not None, ( assert out_logprobs is not None, "Did not output logprobs"
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_completion_logprobs( logprobs = self._create_completion_logprobs(
token_ids=token_ids, token_ids=token_ids,
top_logprobs=top_logprobs, top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs, num_output_top_logprobs=request.logprobs,
) )
else: else:
......
...@@ -62,6 +62,7 @@ class OpenAIServing: ...@@ -62,6 +62,7 @@ class OpenAIServing:
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
model_cards = [ model_cards = [
ModelCard(id=served_model_name, ModelCard(id=served_model_name,
max_model_len=self.max_model_len,
root=self.served_model_names[0], root=self.served_model_names[0],
permission=[ModelPermission()]) permission=[ModelPermission()])
for served_model_name in self.served_model_names for served_model_name in self.served_model_names
...@@ -130,7 +131,8 @@ class OpenAIServing: ...@@ -130,7 +131,8 @@ class OpenAIServing:
prompt_ids: Optional[List[int]] = None, prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int, truncate_prompt_tokens: Optional[Annotated[int,
Field(ge=1)]] = None, Field(ge=1)]] = None,
add_special_tokens: bool = True) -> Tuple[List[int], str]: add_special_tokens: Optional[bool] = True
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids): if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.") raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids): if (prompt and prompt_ids):
...@@ -138,11 +140,12 @@ class OpenAIServing: ...@@ -138,11 +140,12 @@ class OpenAIServing:
"Only one of prompt or prompt_ids should be provided.") "Only one of prompt or prompt_ids should be provided.")
if prompt_ids is None: if prompt_ids is None:
# When using OpenAIServingChat for chat completions, the # When using OpenAIServingChat for chat completions, for
# special tokens (e.g., BOS) have already been added by the # most models the special tokens (e.g., BOS) have already
# chat template. Therefore, we do not need to add them again. # been added by the chat template. Therefore, we do not
# Set add_special_tokens to False to avoid adding the BOS tokens # need to add them again.
# again. # Set add_special_tokens to False (by default) to avoid
# adding the BOS tokens again.
tokenizer_kwargs: Dict[str, Any] = { tokenizer_kwargs: Dict[str, Any] = {
"add_special_tokens": add_special_tokens "add_special_tokens": add_special_tokens
} }
......
...@@ -29,10 +29,10 @@ if TYPE_CHECKING: ...@@ -29,10 +29,10 @@ if TYPE_CHECKING:
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None NVCC_THREADS: Optional[str] = None
VLLM_BUILD_WITH_NEURON: bool = False
VLLM_USE_PRECOMPILED: bool = False VLLM_USE_PRECOMPILED: bool = False
VLLM_INSTALL_PUNICA_KERNELS: bool = False VLLM_INSTALL_PUNICA_KERNELS: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None CMAKE_BUILD_TYPE: Optional[str] = None
...@@ -62,10 +62,6 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -62,10 +62,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"NVCC_THREADS": "NVCC_THREADS":
lambda: os.getenv("NVCC_THREADS", None), lambda: os.getenv("NVCC_THREADS", None),
# If set, vllm will build with Neuron support
"VLLM_BUILD_WITH_NEURON":
lambda: bool(os.environ.get("VLLM_BUILD_WITH_NEURON", False)),
# If set, vllm will use precompiled binaries (*.so) # If set, vllm will use precompiled binaries (*.so)
"VLLM_USE_PRECOMPILED": "VLLM_USE_PRECOMPILED":
lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")), lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")),
...@@ -99,6 +95,9 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -99,6 +95,9 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""),
# used in distributed environment to manually set the communication port # used in distributed environment to manually set the communication port
# Note: if VLLM_PORT is set, and some code asks for multiple ports, the
# VLLM_PORT will be used as the first port, and the rest will be generated
# by incrementing the VLLM_PORT value.
# '0' is used to make mypy happy # '0' is used to make mypy happy
'VLLM_PORT': 'VLLM_PORT':
lambda: int(os.getenv('VLLM_PORT', '0')) lambda: int(os.getenv('VLLM_PORT', '0'))
...@@ -213,6 +212,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -213,6 +212,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Both spawn and fork work # Both spawn and fork work
"VLLM_WORKER_MULTIPROC_METHOD": "VLLM_WORKER_MULTIPROC_METHOD":
lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
# Timeout for fetching images when serving multimodal models
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -19,10 +19,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -19,10 +19,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"""Python multiprocessing-based multi-GPU executor""" """Python multiprocessing-based multi-GPU executor"""
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert (
not self.speculative_config
), "Speculative decoding not yet supported for MultiProcGPU backend."
# Create the parallel GPU workers. # Create the parallel GPU workers.
world_size = self.parallel_config.tensor_parallel_size world_size = self.parallel_config.tensor_parallel_size
...@@ -34,6 +30,9 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -34,6 +30,9 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
from torch.cuda import device_count from torch.cuda import device_count
assert world_size <= device_count(), ( assert world_size <= device_count(), (
"please set tensor_parallel_size to less than max local gpu count") "please set tensor_parallel_size to less than max local gpu count")
...@@ -43,6 +42,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -43,6 +42,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
if world_size == 1: if world_size == 1:
self.workers = [] self.workers = []
self.worker_monitor = None
else: else:
result_handler = ResultHandler() result_handler = ResultHandler()
self.workers = [ self.workers = [
...@@ -124,7 +124,8 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -124,7 +124,8 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
def check_health(self) -> None: def check_health(self) -> None:
"""Raises an error if engine is unhealthy.""" """Raises an error if engine is unhealthy."""
if not self.worker_monitor.is_alive(): if self.worker_monitor is not None and not self.worker_monitor.is_alive(
):
raise RuntimeError("Worker processes are not running") raise RuntimeError("Worker processes are not running")
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
......
...@@ -65,10 +65,11 @@ def _set_future_result(future: Union[ResultFuture, asyncio.Future], ...@@ -65,10 +65,11 @@ def _set_future_result(future: Union[ResultFuture, asyncio.Future],
future.set_result(result) future.set_result(result)
return return
loop = future.get_loop() loop = future.get_loop()
if result.exception is not None: if not loop.is_closed():
loop.call_soon_threadsafe(future.set_exception, result.exception) if result.exception is not None:
else: loop.call_soon_threadsafe(future.set_exception, result.exception)
loop.call_soon_threadsafe(future.set_result, result.value) else:
loop.call_soon_threadsafe(future.set_result, result.value)
class ResultHandler(threading.Thread): class ResultHandler(threading.Thread):
......
...@@ -293,23 +293,6 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -293,23 +293,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
]) ])
return forward_dag.experimental_compile() return forward_dag.experimental_compile()
def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()
def _check_if_any_actor_is_dead(self):
if not self.workers:
return
dead_actors = []
for actor in self.workers:
actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access
if actor_state["State"] == "DEAD":
dead_actors.append(actor)
if dead_actors:
raise RuntimeError("At least one Worker is dead. "
f"Dead Workers: {dead_actors}. ")
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
......
...@@ -215,19 +215,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -215,19 +215,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None: model_config: Optional[PretrainedConfig] = None) -> None:
lora_vocab_start_idx = self.base_layer.org_vocab_size if self.base_layer.num_added_embeddings_per_partition > 0:
weights_idx = None
if self.base_layer.vocab_end_index > lora_vocab_start_idx:
# We can start adding lora weights # We can start adding lora weights
weights_idx = max( self.embeddings_weights = self.base_layer.weight.data[
lora_vocab_start_idx - self.base_layer.vocab_start_index, 0) self.base_layer.num_org_embeddings_per_partition:self.
self.embeddings_slice = (self.base_layer.vocab_start_index - base_layer.num_org_embeddings_per_partition +
self.base_layer.org_vocab_size + self.base_layer.num_added_embeddings_per_partition]
weights_idx, self.embeddings_slice = (
self.base_layer.vocab_end_index - self.base_layer.shard_indices.added_vocab_start_index -
self.base_layer.org_vocab_size) self.base_layer.org_vocab_size,
self.embeddings_weights = self.base_layer.weight.data[weights_idx:] self.base_layer.shard_indices.added_vocab_end_index -
self.embeddings_weights.fill_(0) self.base_layer.org_vocab_size)
self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition:].fill_(0)
else: else:
self.embeddings_slice = None self.embeddings_slice = None
self.embeddings_weights = None self.embeddings_weights = None
...@@ -1025,19 +1025,31 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -1025,19 +1025,31 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
class LogitsProcessorWithLoRA(BaseLayerWithLoRA): class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
"""
LoRA wrapper for LogitsProcessor, with extra logic to handle the
application of the LoRA adapter and added LoRA vocabulary.
Args:
base_layer: LogitsProcessor layer
hidden_size: hidden size of the model
dtype: data type of the model
device: device of the model
sharded_to_full_mapping: index mapping from sharded vocab to full vocab
received from base_layer.get_sharded_to_full_mapping(). If None,
no reindexing will be done.
"""
def __init__( def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
self, dtype: torch.dtype, device: torch.device,
base_layer: LogitsProcessor, sharded_to_full_mapping: Optional[List[int]]) -> None:
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.sharded_to_full_mapping = sharded_to_full_mapping
@property @property
def logits_as_input(self): def logits_as_input(self):
...@@ -1098,6 +1110,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1098,6 +1110,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype=self.dtype, dtype=self.dtype,
device=self.device, device=self.device,
) )
if self.sharded_to_full_mapping is not None:
self.sharded_to_full_mapping_gpu = torch.tensor(
self.sharded_to_full_mapping,
device=self.device,
dtype=torch.long)
else:
self.sharded_to_full_mapping_gpu = None
# Lazily initialized. # Lazily initialized.
self.indices: torch.Tensor self.indices: torch.Tensor
self.indices_len: List[int] self.indices_len: List[int]
...@@ -1154,6 +1173,25 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1154,6 +1173,25 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
if logits is None: if logits is None:
return None return None
if self.sharded_to_full_mapping_gpu is not None:
# Reindex full logits tensor to ensure 1:1 mapping between
# index and token_id
# Example for:
# org_vocab_size = 4
# added_vocab_size = 2
# pad_to_size = 8
# tp_size = 2
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 4, -1, 2, 3, 5, -1]
# Therefore, the mapping is expected to be:
# [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
# we get:
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits = logits[:, self.sharded_to_full_mapping_gpu]
lora_logits = torch.empty( lora_logits = torch.empty(
self.embeddings_tensors.shape[0] + 1, self.embeddings_tensors.shape[0] + 1,
self.embeddings_tensors.shape[1], self.embeddings_tensors.shape[1],
......
...@@ -4,16 +4,21 @@ from typing import Optional ...@@ -4,16 +4,21 @@ from typing import Optional
import torch import torch
from vllm import _custom_ops as ops
def _check_punica_support():
if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
return
def _raise_import_error(e):
if torch.cuda.get_device_capability() < (8, 0): if torch.cuda.get_device_capability() < (8, 0):
raise ImportError( raise ImportError(
"punica LoRA kernels require compute capability >= 8.0") from e "punica LoRA kernels require compute capability >= 8.0")
else: else:
raise ImportError( raise ImportError(
"punica LoRA kernels could not be imported. If you built vLLM " "punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var " "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set.") from e "was set.")
def bgmv( def bgmv(
...@@ -41,12 +46,9 @@ def bgmv( ...@@ -41,12 +46,9 @@ def bgmv(
layer_idx: Layer index of the weight matrices. layer_idx: Layer index of the weight matrices.
scale: Scaling factor. scale: Scaling factor.
""" """
try: _check_punica_support()
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
...@@ -75,11 +77,9 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, ...@@ -75,11 +77,9 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
y_offset: Offset to apply to the starting column of y. y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice. y_slice_size: Size of the y column slice.
""" """
try: _check_punica_support()
import vllm._punica_C as punica_kernels
except ImportError as e: ops.dispatch_bgmv_low_level(
_raise_import_error(e)
punica_kernels.dispatch_bgmv_low_level(
y, y,
x, x,
w_t_all, w_t_all,
...@@ -122,10 +122,7 @@ def add_lora(y: torch.Tensor, ...@@ -122,10 +122,7 @@ def add_lora(y: torch.Tensor,
scale: Scaling factor. scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer. buffer: Optional. Shape: `[B, R]`. Temporary buffer.
""" """
try: _check_punica_support()
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
r = wb_t_all.size(-1) r = wb_t_all.size(-1)
if buffer is None: if buffer is None:
...@@ -135,9 +132,8 @@ def add_lora(y: torch.Tensor, ...@@ -135,9 +132,8 @@ def add_lora(y: torch.Tensor,
buffer = torch.zeros((x.size(0), r), buffer = torch.zeros((x.size(0), r),
dtype=torch.float32, dtype=torch.float32,
device=x.device) device=x.device)
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale)
scale)
def add_lora_slice(y: torch.Tensor, def add_lora_slice(y: torch.Tensor,
...@@ -176,10 +172,7 @@ def add_lora_slice(y: torch.Tensor, ...@@ -176,10 +172,7 @@ def add_lora_slice(y: torch.Tensor,
y_offset: Offset to apply to the starting column of y. y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice. y_slice_size: Size of the y column slice.
""" """
try: _check_punica_support()
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
r = wb_t_all.size(-1) r = wb_t_all.size(-1)
if buffer is None: if buffer is None:
...@@ -189,7 +182,7 @@ def add_lora_slice(y: torch.Tensor, ...@@ -189,7 +182,7 @@ def add_lora_slice(y: torch.Tensor,
buffer = torch.zeros((x.size(0), r), buffer = torch.zeros((x.size(0), r),
dtype=torch.float32, dtype=torch.float32,
device=x.device) device=x.device)
punica_kernels.dispatch_bgmv_low_level( ops.dispatch_bgmv_low_level(
buffer, buffer,
x, x,
wa_t_all, wa_t_all,
...@@ -200,7 +193,7 @@ def add_lora_slice(y: torch.Tensor, ...@@ -200,7 +193,7 @@ def add_lora_slice(y: torch.Tensor,
buffer.size(1), buffer.size(1),
0, 0,
) )
punica_kernels.dispatch_bgmv_low_level( ops.dispatch_bgmv_low_level(
y, y,
buffer, buffer,
wb_t_all, wb_t_all,
......
...@@ -67,7 +67,8 @@ def from_layer_logits_processor( ...@@ -67,7 +67,8 @@ def from_layer_logits_processor(
model_config: Optional[PretrainedConfig] = None, model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA: ) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.dtype, lm_head.weight.device) lm_head.weight.dtype, lm_head.weight.device,
lm_head.get_sharded_to_full_mapping())
ret.create_lora_weights(max_loras, lora_config, model_config) ret.create_lora_weights(max_loras, lora_config, model_config)
return ret return ret
...@@ -93,13 +94,12 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: ...@@ -93,13 +94,12 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
is_lora_a whether the tensor is lora_a or lora_b. is_lora_a whether the tensor is lora_a or lora_b.
""" """
parts = name.split(".") parts = name.split(".")
assert parts[0] == "base_model"
assert parts[1] == "model"
if parts[-1] == "weight":
assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" if parts[-1] == "weight":
if parts[-2] == "lora_A" or parts[-2] == "lora_B":
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
raise ValueError(f"{name} is unsupported format") raise ValueError(f"{name} is unsupported LoRA weight")
from abc import ABC, abstractmethod, abstractproperty from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
...@@ -42,7 +42,8 @@ class AbstractWorkerLoRAManager(ABC): ...@@ -42,7 +42,8 @@ class AbstractWorkerLoRAManager(ABC):
yield yield
self._cached_dummy_lora = False self._cached_dummy_lora = False
@abstractproperty @property
@abstractmethod
def is_enabled(self) -> bool: def is_enabled(self) -> bool:
... ...
......
import torch.nn as nn
from vllm.utils import is_cpu, is_hip
class CustomOp(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self._forward_method = self.dispatch_forward()
def forward(self, *args, **kwargs):
return self._forward_method(*args, **kwargs)
def forward_native(self, *args, **kwargs):
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
"""
raise NotImplementedError
def forward_cuda(self, *args, **kwargs):
raise NotImplementedError
def forward_hip(self, *args, **kwargs):
# By default, we assume that HIP ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)
def forward_xpu(self, *args, **kwargs):
# By default, we assume that XPU ops are compatible with CUDA ops.
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_cuda(*args, **kwargs)
def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)
def forward_tpu(self, *args, **kwargs):
# By default, we assume that TPU ops are compatible with the
# PyTorch-native implementation.
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)
def forward_gaudi(self, *args, **kwargs):
# By default, we assume that Gaudi ops are compatible with the
# PyTorch-native implementation.
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)
def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if is_hip():
return self.forward_hip
elif is_cpu():
return self.forward_cpu
else:
return self.forward_cuda
from typing import Optional, Union from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (
CompletionRequest) ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
get_lm_format_enforcer_guided_decoding_logits_processor) get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_decoding import ( from vllm.model_executor.guided_decoding.outlines_decoding import (
...@@ -13,6 +14,8 @@ async def get_guided_decoding_logits_processor( ...@@ -13,6 +14,8 @@ async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest, guided_decoding_backend: str, request: Union[CompletionRequest,
ChatCompletionRequest], ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]: tokenizer) -> Optional[LogitsProcessor]:
request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines': if guided_decoding_backend == 'outlines':
return await get_outlines_guided_decoding_logits_processor( return await get_outlines_guided_decoding_logits_processor(
request, tokenizer) request, tokenizer)
...@@ -23,3 +26,26 @@ async def get_guided_decoding_logits_processor( ...@@ -23,3 +26,26 @@ async def get_guided_decoding_logits_processor(
raise ValueError( raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. " f"Unknown guided decoding backend '{guided_decoding_backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'") "Must be one of 'outlines, 'lm-format-enforcer'")
def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
if type(request) is CompletionRequest:
return request
# user has chosen to not use any tool
if request.tool_choice == "none":
return request
# user has chosen to use a named tool
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = request.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in request.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
request.guided_json = tool.parameters
return request
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment