Unverified commit a4113b03, authored by Gabriel Marinho and committed by GitHub
Browse files

[Platform] Add custom default max tokens (#18557)


Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
parent 7e1665b0
...@@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
logit_bias: Optional[dict[str, float]] = None logit_bias: Optional[dict[str, float]] = None
logprobs: Optional[bool] = False logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0 top_logprobs: Optional[int] = 0
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field( max_tokens: Optional[int] = Field(
default=None, default=None,
deprecated= deprecated=
...@@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
} }
def to_beam_search_params( def to_beam_search_params(
self, self, max_tokens: int,
default_max_tokens: int, default_sampling_params: dict) -> BeamSearchParams:
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1 n = self.n if self.n is not None else 1
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
if (temperature := self.temperature) is None: if (temperature := self.temperature) is None:
temperature = default_sampling_params.get( temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
...@@ -465,21 +451,10 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -465,21 +451,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
def to_sampling_params( def to_sampling_params(
self, self,
default_max_tokens: int, max_tokens: int,
logits_processor_pattern: Optional[str], logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None, default_sampling_params: dict,
) -> SamplingParams: ) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
# Default parameters # Default parameters
if (repetition_penalty := self.repetition_penalty) is None: if (repetition_penalty := self.repetition_penalty) is None:
...@@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel):
} }
def to_beam_search_params( def to_beam_search_params(
self, self,
default_max_tokens: int, max_tokens: int,
default_sampling_params: Optional[dict] = None default_sampling_params: Optional[dict] = None,
) -> BeamSearchParams: ) -> BeamSearchParams:
max_tokens = self.max_tokens
if default_sampling_params is None: if default_sampling_params is None:
default_sampling_params = {} default_sampling_params = {}
n = self.n if self.n is not None else 1 n = self.n if self.n is not None else 1
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
if (temperature := self.temperature) is None: if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0) temperature = default_sampling_params.get("temperature", 1.0)
...@@ -928,21 +896,14 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -928,21 +896,14 @@ class CompletionRequest(OpenAIBaseModel):
def to_sampling_params( def to_sampling_params(
self, self,
default_max_tokens: int, max_tokens: int,
logits_processor_pattern: Optional[str], logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None, default_sampling_params: Optional[dict] = None,
) -> SamplingParams: ) -> SamplingParams:
max_tokens = self.max_tokens
if default_sampling_params is None: if default_sampling_params is None:
default_sampling_params = {} default_sampling_params = {}
# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)
# Default parameters # Default parameters
if (repetition_penalty := self.repetition_penalty) is None: if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get( repetition_penalty = default_sampling_params.get(
...@@ -1813,7 +1774,7 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -1813,7 +1774,7 @@ class TranscriptionRequest(OpenAIBaseModel):
self, self,
default_max_tokens: int, default_max_tokens: int,
default_sampling_params: Optional[dict] = None) -> SamplingParams: default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = default_max_tokens max_tokens = default_max_tokens
if default_sampling_params is None: if default_sampling_params is None:
...@@ -2029,7 +1990,7 @@ class TranslationRequest(OpenAIBaseModel): ...@@ -2029,7 +1990,7 @@ class TranslationRequest(OpenAIBaseModel):
self, self,
default_max_tokens: int, default_max_tokens: int,
default_sampling_params: Optional[dict] = None) -> SamplingParams: default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = default_max_tokens max_tokens = default_max_tokens
if default_sampling_params is None: if default_sampling_params is None:
......
...@@ -34,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels ...@@ -34,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall) MistralToolCall)
from vllm.entrypoints.utils import get_max_tokens
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
...@@ -233,15 +234,22 @@ class OpenAIServingChat(OpenAIServing): ...@@ -233,15 +234,22 @@ class OpenAIServingChat(OpenAIServing):
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams] sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]) if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=len(engine_prompt["prompt_token_ids"]),
default_sampling_params=self.default_sampling_params)
if request.use_beam_search: if request.use_beam_search:
sampling_params = request.to_beam_search_params( sampling_params = request.to_beam_search_params(
default_max_tokens, self.default_sampling_params) max_tokens, self.default_sampling_params)
else: else:
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, max_tokens, self.model_config.logits_processor_pattern,
self.model_config.logits_processor_pattern,
self.default_sampling_params) self.default_sampling_params)
self._log_inputs(request_id, self._log_inputs(request_id,
......
...@@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing, ...@@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
is_text_tokens_prompt) is_text_tokens_prompt)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt) is_tokens_prompt)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -160,15 +161,22 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -160,15 +161,22 @@ class OpenAIServingCompletion(OpenAIServing):
input_length = len(engine_prompt["prompt_token_ids"]) input_length = len(engine_prompt["prompt_token_ids"])
else: else:
assert_never(engine_prompt) assert_never(engine_prompt)
default_max_tokens = self.max_model_len - input_length
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params)
if request.use_beam_search: if request.use_beam_search:
sampling_params = request.to_beam_search_params( sampling_params = request.to_beam_search_params(
default_max_tokens, self.default_sampling_params) max_tokens, self.default_sampling_params)
else: else:
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, max_tokens, self.model_config.logits_processor_pattern,
self.model_config.logits_processor_pattern,
self.default_sampling_params) self.default_sampling_params)
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
......
...@@ -5,13 +5,17 @@ import argparse ...@@ -5,13 +5,17 @@ import argparse
import asyncio import asyncio
import functools import functools
import os import os
from typing import Any, Optional import sys
from typing import Any, Optional, Union
from fastapi import Request from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from starlette.background import BackgroundTask, BackgroundTasks from starlette.background import BackgroundTask, BackgroundTasks
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -181,7 +185,6 @@ def _validate_truncation_size( ...@@ -181,7 +185,6 @@ def _validate_truncation_size(
def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
subcommand_name: list[str]): subcommand_name: list[str]):
import sys
# Only handle --help=<keyword> for the current subcommand. # Only handle --help=<keyword> for the current subcommand.
# Since subparser_init() runs for all subcommands during CLI setup, # Since subparser_init() runs for all subcommands during CLI setup,
...@@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, ...@@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
print(f"\nNo group or parameter matching '{search_keyword}'") print(f"\nNo group or parameter matching '{search_keyword}'")
print("Tip: use `--help=listgroup` to view all groups.") print("Tip: use `--help=listgroup` to view all groups.")
sys.exit(1) sys.exit(1)
def get_max_tokens(max_model_len: int,
                   request: Union[ChatCompletionRequest, CompletionRequest],
                   input_length: int,
                   default_sampling_params: Optional[dict] = None) -> int:
    """Resolve the number of tokens a single prompt is allowed to generate.

    The result is the smallest of the limits that are actually set:

    * the room left in the context window
      (``max_model_len - input_length``),
    * the limit carried by the request,
    * the cap reported by ``current_platform.get_max_output_tokens``,
    * the ``"max_tokens"`` entry of the server-side sampling defaults.

    Args:
        max_model_len: Size of the model's context window, in tokens.
        request: Incoming completion or chat-completion request.
            ``max_completion_tokens`` is used when the request has that
            attribute and it is truthy; otherwise ``max_tokens`` is used.
        input_length: Number of tokens in the prompt.
        default_sampling_params: Server-side sampling defaults; only the
            ``"max_tokens"`` key is consulted. ``None`` is treated as an
            empty mapping so callers need not normalise it first.

    Returns:
        The minimum of the limits above, ignoring those that are ``None``.
        No check is made here that the remaining context room is positive.
    """
    if default_sampling_params is None:
        default_sampling_params = {}

    # `or` (not an `is None` test) is deliberate: a missing, None or zero
    # `max_completion_tokens` falls back to the legacy `max_tokens` field.
    requested = getattr(request, "max_completion_tokens",
                        None) or request.max_tokens
    context_room = max_model_len - input_length
    platform_cap = current_platform.get_max_output_tokens(input_length)

    limits = (
        context_room,
        requested,
        platform_cap,
        default_sampling_params.get("max_tokens"),
    )
    # `context_room` is always an int, so the filtered sequence is never
    # empty and `min` cannot raise ValueError.
    return min(limit for limit in limits if limit is not None)
...@@ -4,6 +4,7 @@ import enum ...@@ -4,6 +4,7 @@ import enum
import os import os
import platform import platform
import random import random
import sys
from datetime import timedelta from datetime import timedelta
from platform import uname from platform import uname
from typing import TYPE_CHECKING, NamedTuple, Optional, Union from typing import TYPE_CHECKING, NamedTuple, Optional, Union
...@@ -164,6 +165,9 @@ class Platform: ...@@ -164,6 +165,9 @@ class Platform:
def is_out_of_tree(self) -> bool: def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT return self._enum == PlatformEnum.OOT
def get_max_output_tokens(self, prompt_len: int) -> int:
    """Return this platform's cap on generated tokens for a prompt.

    This implementation ignores ``prompt_len`` and returns ``sys.maxsize``,
    so it never wins when combined with other limits via ``min``.

    NOTE(review): presumably platform-specific subclasses override this to
    enforce a tighter, prompt-dependent cap -- confirm against the backends.
    """
    return sys.maxsize
def is_cuda_alike(self) -> bool: def is_cuda_alike(self) -> bool:
"""Stateless version of [torch.cuda.is_available][].""" """Stateless version of [torch.cuda.is_available][]."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment