Unverified Commit baaedfdb authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[mypy] Enable following imports for entrypoints (#7248)


Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: default avatarFei <dfdfcai4@gmail.com>
parent 45066412
...@@ -23,8 +23,8 @@ class AsyncEngineRPCServer: ...@@ -23,8 +23,8 @@ class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs, def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, rpc_path: str): usage_context: UsageContext, rpc_path: str):
# Initialize engine first. # Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, self.engine = AsyncLLMEngine.from_engine_args(
usage_context) async_engine_args, usage_context=usage_context)
# Initialize context. # Initialize context.
self.context = zmq.asyncio.Context() self.context = zmq.asyncio.Context()
...@@ -39,7 +39,7 @@ class AsyncEngineRPCServer: ...@@ -39,7 +39,7 @@ class AsyncEngineRPCServer:
self.context.destroy() self.context.destroy()
self.engine.shutdown_background_loop() self.engine.shutdown_background_loop()
# Clear the engine reference so that it can be GC'ed. # Clear the engine reference so that it can be GC'ed.
self.engine = None del self.engine
async def get_model_config(self, identity): async def get_model_config(self, identity):
"""Send the ModelConfig""" """Send the ModelConfig"""
......
import asyncio import asyncio
import time import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Union from typing import Union
from fastapi import Request from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
...@@ -24,13 +23,14 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -24,13 +23,14 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing, OpenAIServing,
PromptAdapterPath) PromptAdapterPath)
from vllm.inputs import PromptInputs from vllm.inputs import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid from vllm.utils import iterate_with_cancellation, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -67,9 +67,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -67,9 +67,9 @@ class OpenAIServingChat(OpenAIServing):
async def create_chat_completion( async def create_chat_completion(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
raw_request: Optional[Request] = None raw_request: Optional[Request] = None,
) -> Union[ErrorResponse, AsyncGenerator[str, None], ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ChatCompletionResponse]: ErrorResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create See https://platform.openai.com/docs/api-reference/chat/create
...@@ -83,16 +83,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -83,16 +83,6 @@ class OpenAIServingChat(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
if request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid "
f"negative value: {request.prompt_logprobs}")
try: try:
( (
lora_request, lora_request,
...@@ -160,9 +150,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -160,9 +150,8 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
engine_inputs: PromptInputs = { engine_inputs = TokensPrompt(
"prompt_token_ids": prompt_inputs["prompt_token_ids"], prompt_token_ids=prompt_inputs["prompt_token_ids"])
}
if mm_data is not None: if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data engine_inputs["multi_modal_data"] = mm_data
...@@ -214,11 +203,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -214,11 +203,11 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput], result_generator: AsyncIterator[RequestOutput],
request_id: str, request_id: str,
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0] model_name = self.served_model_names[0]
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True first_iteration = True
# Send response for each token for each request.n (index) # Send response for each token for each request.n (index)
...@@ -438,7 +427,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -438,7 +427,7 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput], result_generator: AsyncIterator[RequestOutput],
request_id: str, request_id: str,
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]: ) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0] model_name = self.served_model_names[0]
...@@ -523,7 +512,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -523,7 +512,7 @@ class OpenAIServingChat(OpenAIServing):
def _get_top_logprobs( def _get_top_logprobs(
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]: tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
return [ return [
ChatCompletionLogProb(token=(token := self._get_decoded_token( ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1], p[1],
...@@ -541,12 +530,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -541,12 +530,11 @@ class OpenAIServingChat(OpenAIServing):
self, self,
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
num_output_top_logprobs: Optional[int] = None, num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs: ) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs.""" """Create OpenAI-style logprobs."""
logprobs_content: List[ChatCompletionLogProbsContent] = []
logprobs_content = []
for i, token_id in enumerate(token_ids): for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
...@@ -554,23 +542,32 @@ class OpenAIServingChat(OpenAIServing): ...@@ -554,23 +542,32 @@ class OpenAIServingChat(OpenAIServing):
token = tokenizer.decode(token_id) token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids: if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
logprobs_content.append( logprobs_content.append(
ChatCompletionLogProbsContent( ChatCompletionLogProbsContent(
token=token, token=token,
bytes=list(token.encode("utf-8", errors="replace")))) bytes=list(token.encode("utf-8", errors="replace")),
))
else: else:
step_token = step_top_logprobs[token_id]
step_decoded = step_token.decoded_token
logprobs_content.append( logprobs_content.append(
ChatCompletionLogProbsContent( ChatCompletionLogProbsContent(
token=self._get_decoded_token( token=self._get_decoded_token(
step_top_logprobs[token_id], token_id, tokenizer, step_token,
self.return_tokens_as_token_ids), token_id,
logprob=max(step_top_logprobs[token_id].logprob, tokenizer,
-9999.0), self.return_tokens_as_token_ids,
bytes=list( ),
step_top_logprobs[token_id].decoded_token.encode( logprob=max(step_token.logprob, -9999.0),
"utf-8", errors="replace")), bytes=None if step_decoded is None else list(
step_decoded.encode("utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs( top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs, step_top_logprobs,
tokenizer))) num_output_top_logprobs,
tokenizer,
),
))
return ChatCompletionLogProbs(content=logprobs_content) return ChatCompletionLogProbs(content=logprobs_content)
...@@ -3,10 +3,9 @@ import time ...@@ -3,10 +3,9 @@ import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional) Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Tuple, cast from typing import Tuple, Union, cast
from fastapi import Request from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
...@@ -19,7 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, ...@@ -19,7 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, CompletionStreamResponse,
UsageInfo) ErrorResponse, UsageInfo)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing, OpenAIServing,
...@@ -29,6 +28,7 @@ from vllm.outputs import RequestOutput ...@@ -29,6 +28,7 @@ from vllm.outputs import RequestOutput
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -60,8 +60,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -60,8 +60,11 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids) return_tokens_as_token_ids=return_tokens_as_token_ids)
async def create_completion(self, request: CompletionRequest, async def create_completion(
raw_request: Request): self,
request: CompletionRequest,
raw_request: Request,
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create See https://platform.openai.com/docs/api-reference/completions/create
...@@ -84,15 +87,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -84,15 +87,6 @@ class OpenAIServingCompletion(OpenAIServing):
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time()) created_time = int(time.time())
if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
elif request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid negative "
f"value: {request.prompt_logprobs}")
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = [] generators: List[AsyncGenerator[RequestOutput, None]] = []
try: try:
...@@ -153,8 +147,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -153,8 +147,7 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[ result_generator = merge_async_iterators(
int, RequestOutput]] = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected) *generators, is_cancelled=raw_request.is_disconnected)
# Similar to the OpenAI API, when n != best_of, we do not stream the # Similar to the OpenAI API, when n != best_of, we do not stream the
...@@ -227,7 +220,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -227,7 +220,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int, created_time: int,
model_name: str, model_name: str,
num_prompts: int, num_prompts: int,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts previous_texts = [""] * num_choices * num_prompts
...@@ -236,6 +229,13 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -236,6 +229,13 @@ class OpenAIServingCompletion(OpenAIServing):
try: try:
async for prompt_idx, res in result_generator: async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt
delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]]
for output in res.outputs: for output in res.outputs:
i = output.index + prompt_idx * num_choices i = output.index + prompt_idx * num_choices
...@@ -244,19 +244,25 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -244,19 +244,25 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
assert prompt_text is not None
# only return the prompt # only return the prompt
delta_text = res.prompt delta_text = prompt_text
delta_token_ids = res.prompt_token_ids delta_token_ids = prompt_token_ids
out_logprobs = res.prompt_logprobs out_logprobs = prompt_logprobs
has_echoed[i] = True has_echoed[i] = True
elif (request.echo and request.max_tokens > 0 elif (request.echo and request.max_tokens > 0
and not has_echoed[i]): and not has_echoed[i]):
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token # echo the prompt and first token
delta_text = res.prompt + output.text delta_text = prompt_text + output.text
delta_token_ids = (res.prompt_token_ids + delta_token_ids = [
output.token_ids) *prompt_token_ids, *output.token_ids
out_logprobs = res.prompt_logprobs + (output.logprobs ]
or []) out_logprobs = [
*prompt_logprobs,
*(output.logprobs or []),
]
has_echoed[i] = True has_echoed[i] = True
else: else:
# return just the delta # return just the delta
...@@ -301,7 +307,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -301,7 +307,7 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage): and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None): or output.finish_reason is not None):
prompt_tokens = len(res.prompt_token_ids) prompt_tokens = len(prompt_token_ids)
completion_tokens = len(output.token_ids) completion_tokens = len(output.token_ids)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
...@@ -342,7 +348,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -342,7 +348,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str, request_id: str,
created_time: int, created_time: int,
model_name: str, model_name: str,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> CompletionResponse: ) -> CompletionResponse:
choices: List[CompletionResponseChoice] = [] choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
...@@ -353,16 +359,31 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -353,16 +359,31 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = final_res.prompt_logprobs prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
Logprob]]]]
for output in final_res.outputs: for output in final_res.outputs:
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
assert prompt_text is not None
token_ids = prompt_token_ids token_ids = prompt_token_ids
out_logprobs = prompt_logprobs out_logprobs = prompt_logprobs
output_text = prompt_text output_text = prompt_text
elif request.echo and request.max_tokens > 0: elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + list(output.token_ids) assert prompt_text is not None
out_logprobs = (prompt_logprobs + output.logprobs token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is not None else None)
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text output_text = prompt_text + output.text
else: else:
token_ids = output.token_ids token_ids = output.token_ids
...@@ -413,7 +434,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -413,7 +434,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int, num_output_top_logprobs: int,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
initial_text_offset: int = 0, initial_text_offset: int = 0,
) -> CompletionLogProbs: ) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API.""" """Create logprobs for OpenAI Completion API."""
...@@ -430,17 +451,21 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -430,17 +451,21 @@ class OpenAIServingCompletion(OpenAIServing):
token = tokenizer.decode(token_id) token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids: if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
out_tokens.append(token) out_tokens.append(token)
out_token_logprobs.append(None) out_token_logprobs.append(None)
out_top_logprobs.append(None) out_top_logprobs.append(None)
else: else:
step_token = step_top_logprobs[token_id]
token = self._get_decoded_token( token = self._get_decoded_token(
step_top_logprobs[token_id], step_token,
token_id, token_id,
tokenizer, tokenizer,
return_as_token_id=self.return_tokens_as_token_ids) return_as_token_id=self.return_tokens_as_token_ids,
token_logprob = max(step_top_logprobs[token_id].logprob, )
-9999.0) token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token) out_tokens.append(token)
out_token_logprobs.append(token_logprob) out_token_logprobs.append(token_logprob)
......
import asyncio import asyncio
import base64 import base64
import time import time
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple, from typing import AsyncGenerator, List, Literal, Optional, Union, cast
Union, cast)
import numpy as np import numpy as np
from fastapi import Request from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
...@@ -16,7 +16,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingRequest, ...@@ -16,7 +16,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
ErrorResponse, UsageInfo) ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -24,18 +24,28 @@ logger = init_logger(__name__) ...@@ -24,18 +24,28 @@ logger = init_logger(__name__)
TypeTokenIDs = List[int] TypeTokenIDs = List[int]
def _get_embedding(
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
return output.embedding
elif encoding_format == "base64":
embedding_bytes = np.array(output.embedding).tobytes()
return base64.b64encode(embedding_bytes).decode("utf-8")
assert_never(encoding_format)
def request_output_to_embedding_response( def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str, final_res_batch: List[EmbeddingRequestOutput], request_id: str,
created_time: int, model_name: str, created_time: int, model_name: str,
encoding_format: str) -> EmbeddingResponse: encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = [] data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch): for idx, final_res in enumerate(final_res_batch):
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding embedding = _get_embedding(final_res.outputs, encoding_format)
if encoding_format == "base64":
embedding_bytes = np.array(embedding).tobytes()
embedding = base64.b64encode(embedding_bytes).decode("utf-8")
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data) data.append(embedding_data)
...@@ -76,8 +86,8 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -76,8 +86,8 @@ class OpenAIServingEmbedding(OpenAIServing):
async def create_embedding( async def create_embedding(
self, self,
request: EmbeddingRequest, request: EmbeddingRequest,
raw_request: Optional[Request] = None raw_request: Optional[Request] = None,
) -> Union[ErrorResponse, EmbeddingResponse]: ) -> Union[EmbeddingResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create See https://platform.openai.com/docs/api-reference/embeddings/create
...@@ -89,8 +99,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -89,8 +99,7 @@ class OpenAIServingEmbedding(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
encoding_format = (request.encoding_format encoding_format = request.encoding_format
if request.encoding_format else "float")
if request.dimensions is not None: if request.dimensions is not None:
return self.create_error_response( return self.create_error_response(
"dimensions is currently not supported") "dimensions is currently not supported")
...@@ -145,11 +154,10 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -145,11 +154,10 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[ result_generator = merge_async_iterators(
int, EmbeddingRequestOutput]] = merge_async_iterators(
*generators, *generators,
is_cancelled=raw_request.is_disconnected is_cancelled=raw_request.is_disconnected if raw_request else None,
if raw_request else None) )
# Non-streaming response # Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]] final_res_batch: List[Optional[EmbeddingRequestOutput]]
...@@ -175,7 +183,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -175,7 +183,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return response return response
def _check_embedding_mode(self, embedding_mode: bool): def _check_embedding_mode(self, embedding_mode: bool) -> bool:
if not embedding_mode: if not embedding_mode:
logger.warning( logger.warning(
"embedding_mode is False. Embedding API will not work.") "embedding_mode is False. Embedding API will not work.")
......
...@@ -31,7 +31,7 @@ from vllm.pooling_params import PoolingParams ...@@ -31,7 +31,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -153,6 +153,68 @@ class SamplingParams( ...@@ -153,6 +153,68 @@ class SamplingParams(
output_text_buffer_length: int = 0 output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
@staticmethod
def from_optional(
n: Optional[int] = 1,
best_of: Optional[int] = None,
presence_penalty: Optional[float] = 0.0,
frequency_penalty: Optional[float] = 0.0,
repetition_penalty: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 1.0,
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
min_tokens: int = 0,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
detokenize: bool = True,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
) -> "SamplingParams":
return SamplingParams(
n=1 if n is None else n,
best_of=best_of,
presence_penalty=0.0
if presence_penalty is None else presence_penalty,
frequency_penalty=0.0
if frequency_penalty is None else frequency_penalty,
repetition_penalty=1.0
if repetition_penalty is None else repetition_penalty,
temperature=1.0 if temperature is None else temperature,
top_p=1.0 if top_p is None else top_p,
top_k=top_k,
min_p=min_p,
seed=seed,
use_beam_search=use_beam_search,
length_penalty=length_penalty,
early_stopping=early_stopping,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
min_tokens=min_tokens,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs,
detokenize=detokenize,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.best_of = self.best_of or self.n self.best_of = self.best_of or self.n
if 0 < self.temperature < _MAX_TEMP: if 0 < self.temperature < _MAX_TEMP:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment