Unverified Commit dd7e8f5f authored by Simon Mo, committed by GitHub

refactor completion API for readability (#2499)

parent d2a68364
......@@ -88,6 +88,16 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)

    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5


async def test_single_chat_session(server, client: openai.AsyncOpenAI):
    messages = [{
......
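For context, the token-ID prompt path exercised by the new test above can be hit from any OpenAI-compatible client. A minimal sketch, assuming a vLLM OpenAI-compatible server is already running locally; the base URL, API key, and model name below are placeholders, not part of this change:

import asyncio

import openai

# Placeholder connection details; adjust to the server under test.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def main():
    # Prompts may be plain strings or lists of token IDs; the server
    # tokenizes strings and uses token-ID lists as-is.
    completion = await client.completions.create(
        model="my-model",
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    print(completion.choices[0].text)


asyncio.run(main())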
......@@ -6,6 +6,7 @@ from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
class ErrorResponse(BaseModel):
......@@ -78,6 +79,26 @@ class ChatCompletionRequest(BaseModel):
    repetition_penalty: Optional[float] = 1.0
    min_p: Optional[float] = 0.0

    def to_sampling_params(self) -> SamplingParams:
        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=self.max_tokens,
            best_of=self.best_of,
            top_k=self.top_k,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
        )


class CompletionRequest(BaseModel):
    model: str
......@@ -107,6 +128,30 @@ class CompletionRequest(BaseModel):
    repetition_penalty: Optional[float] = 1.0
    min_p: Optional[float] = 0.0

    def to_sampling_params(self):
        echo_without_generation = self.echo and self.max_tokens == 0

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            logprobs=self.logprobs,
            use_beam_search=self.use_beam_search,
            prompt_logprobs=self.logprobs if self.echo else None,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
        )


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
......
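The only non-obvious branch in `CompletionRequest.to_sampling_params` is the echo-without-generation case: when `echo=True` and `max_tokens=0`, the engine still needs at least one generation step, so `max_tokens` is bumped to 1 and prompt logprobs mirror `logprobs` so the echoed prompt can carry them. A minimal sketch of that behavior, assuming the class is importable as shown; the field values are illustrative only:

from vllm.entrypoints.openai.protocol import CompletionRequest

# Illustrative request exercising the echo edge case handled above.
request = CompletionRequest(
    model="some-model",
    prompt="Hello",
    echo=True,
    max_tokens=0,
    logprobs=5,
)
params = request.to_sampling_params()
assert params.max_tokens == 1       # forced to 1 so the engine runs a step
assert params.prompt_logprobs == 5  # echo requests logprobs on the prompt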
......@@ -11,7 +11,6 @@ from vllm.entrypoints.openai.protocol import (
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
    UsageInfo)
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.entrypoints.openai.serving_engine import OpenAIServing

logger = init_logger(__name__)
......@@ -60,32 +59,11 @@ class OpenAIServingChat(OpenAIServing):
f"Error in applying chat template from request: {str(e)}")
return self.create_error_response(str(e))
token_ids, error_check_ret = await self._check_length(request,
prompt=prompt)
if error_check_ret is not None:
return error_check_ret
request_id = f"cmpl-{random_uuid()}"
try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams(
n=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
top_p=request.top_p,
min_p=request.min_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
token_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
sampling_params = request.to_sampling_params()
except ValueError as e:
return self.create_error_response(str(e))
......
import asyncio
from http import HTTPStatus
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
......@@ -104,27 +104,30 @@ class OpenAIServing:
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
async def _check_length(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[ErrorResponse]]:
assert (not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided."
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None) -> List[int]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
prompt).input_ids
token_num = len(input_ids)
if request.max_tokens is None:
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
return input_ids, self.create_error_response(
raise ValueError(
f"This model's maximum context length is {self.max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
else:
return input_ids, None
return input_ids
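`_validate_prompt_and_tokenize` now signals problems by raising `ValueError` instead of returning an `(input_ids, ErrorResponse)` tuple, so callers only need one `try/except` around tokenization and sampling-parameter construction. The length check itself is plain arithmetic: with a 4096-token context, a 4000-token prompt plus `max_tokens=200` is rejected (4200 > 4096), while omitting `max_tokens` defaults it to the remaining budget (96). A small self-contained sketch of that check in isolation (not the class method above; names here are hypothetical):

from typing import List, Optional


def check_budget(prompt_ids: List[int],
                 max_tokens: Optional[int],
                 max_model_len: int = 4096) -> int:
    """Return the effective max_tokens or raise, mirroring the check above."""
    token_num = len(prompt_ids)
    if max_tokens is None:
        # Default to whatever context budget remains after the prompt.
        max_tokens = max_model_len - token_num
    if token_num + max_tokens > max_model_len:
        raise ValueError(
            f"This model's maximum context length is {max_model_len} tokens. "
            f"However, you requested {token_num + max_tokens} tokens.")
    return max_tokens


assert check_budget([0] * 4000, None) == 96  # defaulted to the remainder
try:
    check_budget([0] * 4000, 200)            # 4200 > 4096 -> rejected
except ValueError:
    pass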
......@@ -163,7 +163,7 @@ def prepare_hf_model_weights(
                    use_safetensors = True
                    break

        logger.info(f"Downloading model weights {allow_patterns}")
        logger.info(f"Using model weights format {allow_patterns}")
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
......