"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "aaa8a567dc86fd216b22f26ed489b8cd7aaa901b"
Unverified Commit f59fc60f authored by Max Wittig, committed by GitHub
Browse files

[Feat][CLI] enforce-include-usage (#19695)


Signed-off-by: Max Wittig <max.wittig@siemens.com>
parent 879f69be
...@@ -1190,6 +1190,7 @@ async def init_app_state( ...@@ -1190,6 +1190,7 @@ async def init_app_state(
tool_parser=args.tool_call_parser, tool_parser=args.tool_call_parser,
reasoning_parser=args.reasoning_parser, reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None ) if model_config.runner_type == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion( state.openai_serving_completion = OpenAIServingCompletion(
engine_client, engine_client,
...@@ -1197,6 +1198,7 @@ async def init_app_state( ...@@ -1197,6 +1198,7 @@ async def init_app_state(
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None ) if model_config.runner_type == "generate" else None
state.openai_serving_pooling = OpenAIServingPooling( state.openai_serving_pooling = OpenAIServingPooling(
engine_client, engine_client,
......
...@@ -272,6 +272,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ...@@ -272,6 +272,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action='store_true', action='store_true',
default=False, default=False,
help="If set to True, enable prompt_tokens_details in usage.") help="If set to True, enable prompt_tokens_details in usage.")
parser.add_argument(
"--enable-force-include-usage",
action='store_true',
default=False,
help="If set to True, including usage on every request.")
parser.add_argument( parser.add_argument(
"--enable-server-load-tracking", "--enable-server-load-tracking",
action='store_true', action='store_true',
......
...@@ -64,12 +64,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -64,12 +64,14 @@ class OpenAIServingChat(OpenAIServing):
enable_auto_tools: bool = False, enable_auto_tools: bool = False,
tool_parser: Optional[str] = None, tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False, enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
) -> None: ) -> None:
super().__init__(engine_client=engine_client, super().__init__(engine_client=engine_client,
model_config=model_config, model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids) return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage)
self.response_role = response_role self.response_role = response_role
self.chat_template = chat_template self.chat_template = chat_template
...@@ -110,6 +112,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -110,6 +112,7 @@ class OpenAIServingChat(OpenAIServing):
"been registered") from e "been registered") from e
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = ( self.default_sampling_params = (
self.model_config.get_diff_sampling_param()) self.model_config.get_diff_sampling_param())
if self.default_sampling_params: if self.default_sampling_params:
...@@ -261,8 +264,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -261,8 +264,14 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response # Streaming response
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
request, result_generator, request_id, model_name, request,
conversation, tokenizer, request_metadata) result_generator,
request_id,
model_name,
conversation,
tokenizer,
request_metadata,
enable_force_include_usage=self.enable_force_include_usage)
try: try:
return await self.chat_completion_full_generator( return await self.chat_completion_full_generator(
...@@ -405,6 +414,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -405,6 +414,7 @@ class OpenAIServingChat(OpenAIServing):
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
...@@ -471,7 +481,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -471,7 +481,8 @@ class OpenAIServingChat(OpenAIServing):
stream_options = request.stream_options stream_options = request.stream_options
if stream_options: if stream_options:
include_usage = stream_options.include_usage include_usage = stream_options.include_usage \
or enable_force_include_usage
include_continuous_usage = include_usage and \ include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats stream_options.continuous_usage_stats
else: else:
......
...@@ -52,12 +52,14 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -52,12 +52,14 @@ class OpenAIServingCompletion(OpenAIServing):
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
enable_force_include_usage: bool = False,
): ):
super().__init__(engine_client=engine_client, super().__init__(engine_client=engine_client,
model_config=model_config, model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids) return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage)
self.default_sampling_params = ( self.default_sampling_params = (
self.model_config.get_diff_sampling_param()) self.model_config.get_diff_sampling_param())
if self.default_sampling_params: if self.default_sampling_params:
...@@ -227,7 +229,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -227,7 +229,8 @@ class OpenAIServingCompletion(OpenAIServing):
model_name, model_name,
num_prompts=num_prompts, num_prompts=num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
request_metadata=request_metadata) request_metadata=request_metadata,
enable_force_include_usage=self.enable_force_include_usage)
# Non-streaming response # Non-streaming response
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
...@@ -289,6 +292,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -289,6 +292,7 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts: int, num_prompts: int,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts previous_text_lens = [0] * num_choices * num_prompts
...@@ -298,7 +302,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -298,7 +302,8 @@ class OpenAIServingCompletion(OpenAIServing):
stream_options = request.stream_options stream_options = request.stream_options
if stream_options: if stream_options:
include_usage = stream_options.include_usage include_usage = stream_options.include_usage or \
enable_force_include_usage
include_continuous_usage = include_usage and \ include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats stream_options.continuous_usage_stats
else: else:
......
...@@ -132,7 +132,7 @@ RequestT = TypeVar("RequestT", bound=AnyRequest) ...@@ -132,7 +132,7 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
class RequestProcessingMixin(BaseModel): class RequestProcessingMixin(BaseModel):
""" """
Mixin for request processing, Mixin for request processing,
handling prompt preparation and engine input. handling prompt preparation and engine input.
""" """
request_prompts: Optional[Sequence[RequestPrompt]] = [] request_prompts: Optional[Sequence[RequestPrompt]] = []
...@@ -144,7 +144,7 @@ class RequestProcessingMixin(BaseModel): ...@@ -144,7 +144,7 @@ class RequestProcessingMixin(BaseModel):
class ResponseGenerationMixin(BaseModel): class ResponseGenerationMixin(BaseModel):
""" """
Mixin for response generation, Mixin for response generation,
managing result generators and final batch results. managing result generators and final batch results.
""" """
result_generator: Optional[AsyncGenerator[tuple[int, Union[ result_generator: Optional[AsyncGenerator[tuple[int, Union[
...@@ -208,6 +208,7 @@ class OpenAIServing: ...@@ -208,6 +208,7 @@ class OpenAIServing:
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
enable_force_include_usage: bool = False,
): ):
super().__init__() super().__init__()
...@@ -219,6 +220,7 @@ class OpenAIServing: ...@@ -219,6 +220,7 @@ class OpenAIServing:
self.request_logger = request_logger self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.enable_force_include_usage = enable_force_include_usage
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment