"vscode:/vscode.git/clone" did not exist on "3c4cebf751a6d2ff9ada2f8234bab17ba7283e09"
Unverified Commit 1d7c940d authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

Add option to completion API to truncate prompt tokens (#3144)

parent cfaf49a1
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Literal, Optional, Union
import torch import torch
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, conint, model_validator
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -229,6 +229,7 @@ class CompletionRequest(BaseModel): ...@@ -229,6 +229,7 @@ class CompletionRequest(BaseModel):
min_tokens: Optional[int] = 0 min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True
truncate_prompt_tokens: Optional[conint(ge=1)] = None
# doc: end-completion-sampling-params # doc: end-completion-sampling-params
# doc: begin-completion-extra-params # doc: begin-completion-extra-params
...@@ -309,6 +310,7 @@ class CompletionRequest(BaseModel): ...@@ -309,6 +310,7 @@ class CompletionRequest(BaseModel):
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
) )
@model_validator(mode="before") @model_validator(mode="before")
......
...@@ -137,10 +137,16 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -137,10 +137,16 @@ class OpenAIServingCompletion(OpenAIServing):
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
if prompt_is_tokens: if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize( input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt) request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else: else:
input_ids = self._validate_prompt_and_tokenize( input_ids = self._validate_prompt_and_tokenize(
request, prompt=prompt) request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
generators.append( generators.append(
self.engine.generate(prompt, self.engine.generate(prompt,
......
...@@ -4,6 +4,8 @@ from dataclasses import dataclass ...@@ -4,6 +4,8 @@ from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from pydantic import conint
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest, ErrorResponse, CompletionRequest, ErrorResponse,
...@@ -66,7 +68,8 @@ class OpenAIServing: ...@@ -66,7 +68,8 @@ class OpenAIServing:
self.tokenizer = get_tokenizer( self.tokenizer = get_tokenizer(
engine_model_config.tokenizer, engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode, tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code) trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")
async def show_available_models(self) -> ModelList: async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
...@@ -164,15 +167,26 @@ class OpenAIServing: ...@@ -164,15 +167,26 @@ class OpenAIServing:
self, self,
request: Union[ChatCompletionRequest, CompletionRequest], request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None, prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None) -> List[int]: prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None
) -> List[int]:
if not (prompt or prompt_ids): if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.") raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids): if (prompt and prompt_ids):
raise ValueError( raise ValueError(
"Only one of prompt or prompt_ids should be provided.") "Only one of prompt or prompt_ids should be provided.")
input_ids = prompt_ids if prompt_ids is not None else self.tokenizer( if prompt_ids is None:
prompt).input_ids tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
"truncation": True,
"max_length": truncate_prompt_tokens,
}
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
else:
input_ids = prompt_ids
token_num = len(input_ids) token_num = len(input_ids)
if request.max_tokens is None: if request.max_tokens is None:
......
...@@ -5,6 +5,7 @@ from functools import cached_property ...@@ -5,6 +5,7 @@ from functools import cached_property
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
from pydantic import conint
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
...@@ -94,6 +95,9 @@ class SamplingParams: ...@@ -94,6 +95,9 @@ class SamplingParams:
tokens in the output. Defaults to True. tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on logits_processors: List of functions that modify logits based on
previously generated tokens. previously generated tokens.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
""" """
def __init__( def __init__(
...@@ -123,6 +127,7 @@ class SamplingParams: ...@@ -123,6 +127,7 @@ class SamplingParams:
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True, spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None, logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None,
) -> None: ) -> None:
self.n = n self.n = n
self.best_of = best_of if best_of is not None else n self.best_of = best_of if best_of is not None else n
...@@ -160,6 +165,7 @@ class SamplingParams: ...@@ -160,6 +165,7 @@ class SamplingParams:
self.spaces_between_special_tokens = spaces_between_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
self._verify_args() self._verify_args()
if self.use_beam_search: if self.use_beam_search:
self._verify_beam_search() self._verify_beam_search()
...@@ -216,6 +222,10 @@ class SamplingParams: ...@@ -216,6 +222,10 @@ class SamplingParams:
if self.prompt_logprobs is not None and self.prompt_logprobs < 0: if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got " raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.") f"{self.prompt_logprobs}.")
if (self.truncate_prompt_tokens is not None
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
if self.stop and not self.detokenize: if self.stop and not self.detokenize:
raise ValueError( raise ValueError(
"stop strings are only supported when detokenize is True. " "stop strings are only supported when detokenize is True. "
...@@ -300,4 +310,5 @@ class SamplingParams: ...@@ -300,4 +310,5 @@ class SamplingParams:
f"prompt_logprobs={self.prompt_logprobs}, " f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens}, " f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens=" "spaces_between_special_tokens="
f"{self.spaces_between_special_tokens})") f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment