Unverified Commit fd85c9f4 authored by Max Wittig's avatar Max Wittig Committed by GitHub
Browse files

[Bugfix][FE]: Always include usage with `--enable-force-include-usage` (#20983)


Signed-off-by: Max Wittig <max.wittig@siemens.com>
Signed-off-by: Antoine Auger <antoineauger@users.noreply.github.com>
Co-authored-by: Antoine Auger <antoineauger@users.noreply.github.com>
parent d32c611f
...@@ -107,6 +107,7 @@ markers = [ ...@@ -107,6 +107,7 @@ markers = [
"distributed: run this test only in distributed GPU tests", "distributed: run this test only in distributed GPU tests",
"skip_v1: do not run this test with v1", "skip_v1: do not run this test with v1",
"optional: optional tests that are automatically skipped, include --optional to run them", "optional: optional tests that are automatically skipped, include --optional to run them",
"extra_server_args: extra arguments to pass to the server fixture",
] ]
[tool.ty.src] [tool.ty.src]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import openai
import pytest
import pytest_asyncio
from ...utils import RemoteOpenAIServer
@pytest.fixture(scope="module")
def chat_server_with_force_include_usage(request):  # noqa: F811
    """Yield a module-scoped chat server started with forced usage reporting.

    The server runs ``Qwen/Qwen3-0.6B`` through ``RemoteOpenAIServer`` with
    ``--enable-force-include-usage`` and is bound to the fixed port 55857
    (``auto_port=False``). ``request`` is accepted but not used by this
    fixture.
    """
    model_name = "Qwen/Qwen3-0.6B"
    engine_flags = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "128",
        "--enforce-eager",
        "--max-num-seqs",
        "1",
    ]
    server_args = [
        *engine_flags,
        "--enable-force-include-usage",
        # The port is pinned here, so automatic port selection is disabled below.
        "--port",
        "55857",
        "--gpu-memory-utilization",
        "0.2",
    ]
    with RemoteOpenAIServer(model_name, server_args, auto_port=False) as server:
        yield server
@pytest_asyncio.fixture
async def chat_client_with_force_include_usage(chat_server_with_force_include_usage):
    """Yield an async client obtained from the force-include-usage chat server.

    The client is created via the server's ``get_async_client()`` and is
    closed by the ``async with`` block when the fixture is torn down.
    """
    server = chat_server_with_force_include_usage
    async with server.get_async_client() as client:
        yield client
@pytest.mark.asyncio
async def test_chat_with_enable_force_include_usage(
    chat_client_with_force_include_usage: openai.AsyncOpenAI,
):
    """Streamed chat completions must carry usage even without stream_options.

    The request deliberately omits ``stream_options``; the server fixture is
    started with ``--enable-force-include-usage``. Chunks that carry choices
    are expected to have no usage, while chunks without choices must carry a
    consistent usage record. The test fails if no usage chunk arrives at all.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    stream = await chat_client_with_force_include_usage.chat.completions.create(
        model="Qwen/Qwen3-0.6B",
        messages=messages,
        max_completion_tokens=10,
        extra_body=dict(min_tokens=10),
        temperature=0.0,
        stream=True,
    )

    usage_chunks_seen = 0
    last_completion_tokens = 0
    async for chunk in stream:
        if chunk.choices:
            # Content chunks must not carry usage (no continuous usage stats
            # were requested).
            assert chunk.usage is None
            continue

        usage = chunk.usage
        assert usage is not None, "chunk without choices is missing usage"
        usage_chunks_seen += 1

        assert usage.prompt_tokens >= 0
        # Completion token counts may only stay equal or grow between
        # successive usage chunks.
        assert usage.completion_tokens >= last_completion_tokens
        assert usage.total_tokens == (
            usage.prompt_tokens + usage.completion_tokens
        )
        # Remember the count so the monotonicity check above is meaningful
        # for any later usage chunk.
        last_completion_tokens = usage.completion_tokens

    # Without this check the test would pass even if the server never sent
    # usage, which is exactly what the flag is supposed to guarantee.
    assert usage_chunks_seen > 0, "no usage chunk received from the stream"
@pytest.fixture(scope="module")
def transcription_server_with_force_include_usage():
    """Yield a module-scoped transcription server with forced usage reporting.

    Starts ``openai/whisper-large-v3-turbo`` through ``RemoteOpenAIServer``
    with ``--enable-force-include-usage`` and yields the running server.
    """
    model_name = "openai/whisper-large-v3-turbo"
    server_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
    ]
    server_args += ["--max-num-seqs", "1"]
    server_args += ["--enforce-eager", "--enable-force-include-usage"]
    server_args += ["--gpu-memory-utilization", "0.2"]
    with RemoteOpenAIServer(model_name, server_args) as server:
        yield server
@pytest_asyncio.fixture
async def transcription_client_with_force_include_usage(
    transcription_server_with_force_include_usage,
):
    """Yield an async client obtained from the transcription server fixture.

    The client comes from the server's ``get_async_client()`` and is closed
    by the ``async with`` block when the fixture is torn down.
    """
    server = transcription_server_with_force_include_usage
    async with server.get_async_client() as client:
        yield client
@pytest.mark.asyncio
async def test_transcription_with_enable_force_include_usage(
    transcription_client_with_force_include_usage, winning_call
):
    """Streamed transcriptions must end with usage when the flag is enabled.

    The request does not ask for usage itself; the server fixture is started
    with ``--enable-force-include-usage``. Chunks with choices must not have
    a ``usage`` attribute, and a chunk without choices must carry a usage
    dict with positive token counts. The test fails if no such chunk arrives.

    ``winning_call`` is the audio file fixture passed as the ``file`` argument
    (defined outside this file).
    """
    res = (
        await transcription_client_with_force_include_usage.audio.transcriptions.create(
            model="openai/whisper-large-v3-turbo",
            file=winning_call,
            language="en",
            temperature=0.0,
            stream=True,
            timeout=30,
        )
    )

    usage_chunks_seen = 0
    async for chunk in res:
        if chunk.choices:
            assert not hasattr(chunk, "usage")
            continue

        # final usage sent
        usage = chunk.usage
        assert isinstance(usage, dict)
        usage_chunks_seen += 1
        assert usage["prompt_tokens"] > 0
        assert usage["completion_tokens"] > 0
        assert usage["total_tokens"] > 0

    # Guard against a vacuous pass: the flag must actually produce a usage
    # chunk, otherwise none of the assertions above would ever run.
    assert usage_chunks_seen > 0, "no usage chunk received from the stream"
...@@ -1808,6 +1808,7 @@ async def init_app_state( ...@@ -1808,6 +1808,7 @@ async def init_app_state(
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
) )
if "transcription" in supported_tasks if "transcription" in supported_tasks
else None else None
...@@ -1818,6 +1819,7 @@ async def init_app_state( ...@@ -1818,6 +1819,7 @@ async def init_app_state(
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
) )
if "transcription" in supported_tasks if "transcription" in supported_tasks
else None else None
......
...@@ -104,6 +104,13 @@ def make_arg_parser(parser: FlexibleArgumentParser): ...@@ -104,6 +104,13 @@ def make_arg_parser(parser: FlexibleArgumentParser):
default=False, default=False,
help="If set to True, enable prompt_tokens_details in usage.", help="If set to True, enable prompt_tokens_details in usage.",
) )
parser.add_argument(
"--enable-force-include-usage",
action="store_true",
default=False,
help="If set to True, include usage on every request "
"(even when stream_options is not specified)",
)
return parser return parser
...@@ -361,6 +368,7 @@ async def run_batch( ...@@ -361,6 +368,7 @@ async def run_batch(
chat_template=None, chat_template=None,
chat_template_content_format="auto", chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) )
if "generate" in supported_tasks if "generate" in supported_tasks
else None else None
......
...@@ -58,7 +58,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l ...@@ -58,7 +58,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
...@@ -101,7 +101,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -101,7 +101,6 @@ class OpenAIServingChat(OpenAIServing):
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage,
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
) )
...@@ -352,7 +351,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -352,7 +351,6 @@ class OpenAIServingChat(OpenAIServing):
conversation, conversation,
tokenizer, tokenizer,
request_metadata, request_metadata,
enable_force_include_usage=self.enable_force_include_usage,
) )
try: try:
...@@ -518,7 +516,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -518,7 +516,6 @@ class OpenAIServingChat(OpenAIServing):
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
...@@ -596,13 +593,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -596,13 +593,9 @@ class OpenAIServingChat(OpenAIServing):
return return
stream_options = request.stream_options stream_options = request.stream_options
if stream_options: include_usage, include_continuous_usage = should_include_usage(
include_usage = stream_options.include_usage or enable_force_include_usage stream_options, self.enable_force_include_usage
include_continuous_usage = (
include_usage and stream_options.continuous_usage_stats
) )
else:
include_usage, include_continuous_usage = False, False
try: try:
async for res in result_generator: async for res in result_generator:
......
...@@ -27,7 +27,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -27,7 +27,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
...@@ -56,11 +56,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -56,11 +56,11 @@ class OpenAIServingCompletion(OpenAIServing):
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage,
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
) )
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = self.model_config.get_diff_sampling_param() self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.enable_force_include_usage = enable_force_include_usage
if self.default_sampling_params: if self.default_sampling_params:
source = self.model_config.generation_config source = self.model_config.generation_config
source = "model" if source == "auto" else source source = "model" if source == "auto" else source
...@@ -256,7 +256,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -256,7 +256,6 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts=num_prompts, num_prompts=num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
request_metadata=request_metadata, request_metadata=request_metadata,
enable_force_include_usage=self.enable_force_include_usage,
) )
# Non-streaming response # Non-streaming response
...@@ -320,7 +319,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -320,7 +319,6 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts: int, num_prompts: int,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts previous_text_lens = [0] * num_choices * num_prompts
...@@ -331,13 +329,9 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -331,13 +329,9 @@ class OpenAIServingCompletion(OpenAIServing):
first_iteration = True first_iteration = True
stream_options = request.stream_options stream_options = request.stream_options
if stream_options: include_usage, include_continuous_usage = should_include_usage(
include_usage = stream_options.include_usage or enable_force_include_usage stream_options, self.enable_force_include_usage
include_continuous_usage = (
include_usage and stream_options.continuous_usage_stats
) )
else:
include_usage, include_continuous_usage = False, False
try: try:
async for prompt_idx, res in result_generator: async for prompt_idx, res in result_generator:
......
...@@ -249,7 +249,6 @@ class OpenAIServing: ...@@ -249,7 +249,6 @@ class OpenAIServing:
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False, log_error_stack: bool = False,
): ):
super().__init__() super().__init__()
...@@ -260,8 +259,6 @@ class OpenAIServing: ...@@ -260,8 +259,6 @@ class OpenAIServing:
self.request_logger = request_logger self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.enable_force_include_usage = enable_force_include_usage
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self._apply_mistral_chat_template_async = make_async( self._apply_mistral_chat_template_async = make_async(
apply_mistral_chat_template, executor=self._tokenizer_executor apply_mistral_chat_template, executor=self._tokenizer_executor
......
...@@ -127,7 +127,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -127,7 +127,6 @@ class OpenAIServingResponses(OpenAIServing):
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage,
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
) )
......
...@@ -37,6 +37,7 @@ class OpenAIServingTranscription(OpenAISpeechToText): ...@@ -37,6 +37,7 @@ class OpenAIServingTranscription(OpenAISpeechToText):
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False, log_error_stack: bool = False,
enable_force_include_usage: bool = False,
): ):
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
...@@ -45,6 +46,7 @@ class OpenAIServingTranscription(OpenAISpeechToText): ...@@ -45,6 +46,7 @@ class OpenAIServingTranscription(OpenAISpeechToText):
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="transcribe", task_type="transcribe",
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage,
) )
async def create_transcription( async def create_transcription(
...@@ -96,6 +98,7 @@ class OpenAIServingTranslation(OpenAISpeechToText): ...@@ -96,6 +98,7 @@ class OpenAIServingTranslation(OpenAISpeechToText):
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False, log_error_stack: bool = False,
enable_force_include_usage: bool = False,
): ):
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
...@@ -104,6 +107,7 @@ class OpenAIServingTranslation(OpenAISpeechToText): ...@@ -104,6 +107,7 @@ class OpenAIServingTranslation(OpenAISpeechToText):
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="translate", task_type="translate",
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage,
) )
async def create_translation( async def create_translation(
......
...@@ -58,6 +58,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -58,6 +58,7 @@ class OpenAISpeechToText(OpenAIServing):
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
task_type: Literal["transcribe", "translate"] = "transcribe", task_type: Literal["transcribe", "translate"] = "transcribe",
log_error_stack: bool = False, log_error_stack: bool = False,
enable_force_include_usage: bool = False,
): ):
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
...@@ -74,6 +75,8 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -74,6 +75,8 @@ class OpenAISpeechToText(OpenAIServing):
self.model_config, task_type self.model_config, task_type
) )
self.enable_force_include_usage = enable_force_include_usage
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
if self.default_sampling_params: if self.default_sampling_params:
...@@ -261,9 +264,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -261,9 +264,7 @@ class OpenAISpeechToText(OpenAIServing):
completion_tokens = 0 completion_tokens = 0
num_prompt_tokens = 0 num_prompt_tokens = 0
include_usage = ( include_usage = self.enable_force_include_usage or request.stream_include_usage
request.stream_include_usage if request.stream_include_usage else False
)
include_continuous_usage = ( include_continuous_usage = (
request.stream_continuous_usage_stats request.stream_continuous_usage_stats
if include_usage and request.stream_continuous_usage_stats if include_usage and request.stream_continuous_usage_stats
......
...@@ -14,7 +14,11 @@ from starlette.background import BackgroundTask, BackgroundTasks ...@@ -14,7 +14,11 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
StreamOptions,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -237,3 +241,16 @@ def log_non_default_args(args: Namespace | EngineArgs): ...@@ -237,3 +241,16 @@ def log_non_default_args(args: Namespace | EngineArgs):
) )
logger.info("non-default args: %s", non_default_args) logger.info("non-default args: %s", non_default_args)
def should_include_usage(
    stream_options: StreamOptions | None, enable_force_include_usage: bool
) -> tuple[bool, bool]:
    """Decide whether a streaming response should report token usage.

    Args:
        stream_options: The request's stream options, or ``None`` when the
            client did not supply any.
        enable_force_include_usage: Server-side switch that turns usage
            reporting on regardless of the request.

    Returns:
        A ``(include_usage, include_continuous_usage)`` pair. Continuous
        usage is only enabled when usage itself is included and the request's
        ``continuous_usage_stats`` is truthy; it is never enabled when no
        stream options were given.
    """
    # No (or falsy) stream options: only the server-side switch matters.
    if not stream_options:
        return enable_force_include_usage, False

    include_usage = stream_options.include_usage or enable_force_include_usage
    wants_continuous = bool(stream_options.continuous_usage_stats)
    return include_usage, include_usage and wants_continuous
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment