Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1ac3de09
Unverified
Commit
1ac3de09
authored
Sep 25, 2024
by
Adam Tilghman
Committed by
GitHub
Sep 25, 2024
Browse files
[Frontend] OpenAI server: propagate usage accounting to FastAPI middleware layer (#8672)
parent
3e073e66
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
11 deletions
+57
-11
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+5
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+23
-3
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+29
-8
No files found.
vllm/entrypoints/openai/protocol.py
View file @
1ac3de09
...
@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
...
@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens
:
Optional
[
int
]
=
0
completion_tokens
:
Optional
[
int
]
=
0
class
RequestResponseMetadata
(
BaseModel
):
request_id
:
str
final_usage_info
:
Optional
[
UsageInfo
]
=
None
class
JsonSchemaResponseFormat
(
OpenAIBaseModel
):
class
JsonSchemaResponseFormat
(
OpenAIBaseModel
):
name
:
str
name
:
str
description
:
Optional
[
str
]
=
None
description
:
Optional
[
str
]
=
None
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
1ac3de09
...
@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaFunctionCall
,
DeltaMessage
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ErrorResponse
,
FunctionCall
,
ToolCall
,
UsageInfo
)
DeltaToolCall
,
ErrorResponse
,
FunctionCall
,
RequestResponseMetadata
,
ToolCall
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
LoRAModulePath
,
LoRAModulePath
,
OpenAIServing
,
OpenAIServing
,
...
@@ -175,6 +176,11 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -175,6 +176,11 @@ class OpenAIServingChat(OpenAIServing):
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
request_id
=
f
"chat-
{
random_uuid
()
}
"
request_id
=
f
"chat-
{
random_uuid
()
}
"
request_metadata
=
RequestResponseMetadata
(
request_id
=
request_id
)
if
raw_request
:
raw_request
.
state
.
request_metadata
=
request_metadata
try
:
try
:
guided_decode_logits_processor
=
(
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
...
@@ -241,11 +247,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -241,11 +247,13 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
# Streaming response
if
request
.
stream
:
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
return
self
.
chat_completion_stream_generator
(
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
,
request_metadata
)
try
:
try
:
return
await
self
.
chat_completion_full_generator
(
return
await
self
.
chat_completion_full_generator
(
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
,
request_metadata
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
@@ -262,6 +270,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -262,6 +270,7 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
request_id
:
str
,
conversation
:
List
[
ConversationMessage
],
conversation
:
List
[
ConversationMessage
],
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
AsyncGenerator
[
str
,
None
]:
)
->
AsyncGenerator
[
str
,
None
]:
model_name
=
self
.
base_model_paths
[
0
].
name
model_name
=
self
.
base_model_paths
[
0
].
name
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
...
@@ -580,6 +589,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -580,6 +589,13 @@ class OpenAIServingChat(OpenAIServing):
exclude_unset
=
True
,
exclude_none
=
True
))
exclude_unset
=
True
,
exclude_none
=
True
))
yield
f
"data:
{
final_usage_data
}
\n\n
"
yield
f
"data:
{
final_usage_data
}
\n\n
"
# report to FastAPI middleware aggregate usage across all choices
num_completion_tokens
=
sum
(
previous_num_tokens
)
request_metadata
.
final_usage_info
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
num_completion_tokens
,
total_tokens
=
num_prompt_tokens
+
num_completion_tokens
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
logger
.
error
(
"error in chat completion stream generator: %s"
,
e
)
logger
.
error
(
"error in chat completion stream generator: %s"
,
e
)
...
@@ -595,6 +611,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -595,6 +611,7 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
request_id
:
str
,
conversation
:
List
[
ConversationMessage
],
conversation
:
List
[
ConversationMessage
],
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
Union
[
ErrorResponse
,
ChatCompletionResponse
]:
)
->
Union
[
ErrorResponse
,
ChatCompletionResponse
]:
model_name
=
self
.
base_model_paths
[
0
].
name
model_name
=
self
.
base_model_paths
[
0
].
name
...
@@ -714,6 +731,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -714,6 +731,9 @@ class OpenAIServingChat(OpenAIServing):
completion_tokens
=
num_generated_tokens
,
completion_tokens
=
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
)
)
request_metadata
.
final_usage_info
=
usage
response
=
ChatCompletionResponse
(
response
=
ChatCompletionResponse
(
id
=
request_id
,
id
=
request_id
,
created
=
created_time
,
created
=
created_time
,
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
1ac3de09
...
@@ -18,7 +18,9 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
...
@@ -18,7 +18,9 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice
,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
CompletionStreamResponse
,
ErrorResponse
,
UsageInfo
)
ErrorResponse
,
RequestResponseMetadata
,
UsageInfo
)
# yapf: enable
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
LoRAModulePath
,
LoRAModulePath
,
...
@@ -94,6 +96,10 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -94,6 +96,10 @@ class OpenAIServingCompletion(OpenAIServing):
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
request_metadata
=
RequestResponseMetadata
(
request_id
=
request_id
)
if
raw_request
:
raw_request
.
state
.
request_metadata
=
request_metadata
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
List
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
...
@@ -165,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -165,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response
# Streaming response
if
stream
:
if
stream
:
return
self
.
completion_stream_generator
(
request
,
return
self
.
completion_stream_generator
(
request
,
result_generator
,
result_generator
,
request_id
,
request_id
,
created_time
,
created_time
,
model_name
,
model_name
,
num_prompts
=
len
(
prompts
),
num_prompts
=
len
(
prompts
),
tokenizer
=
tokenizer
)
tokenizer
=
tokenizer
,
request_metadata
=
request_metadata
)
# Non-streaming response
# Non-streaming response
final_res_batch
:
List
[
Optional
[
RequestOutput
]]
=
[
None
]
*
len
(
prompts
)
final_res_batch
:
List
[
Optional
[
RequestOutput
]]
=
[
None
]
*
len
(
prompts
)
...
@@ -198,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -198,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time
,
created_time
,
model_name
,
model_name
,
tokenizer
,
tokenizer
,
request_metadata
,
)
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
return
self
.
create_error_response
(
"Client disconnected"
)
...
@@ -227,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -227,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing):
model_name
:
str
,
model_name
:
str
,
num_prompts
:
int
,
num_prompts
:
int
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
AsyncGenerator
[
str
,
None
]:
)
->
AsyncGenerator
[
str
,
None
]:
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
previous_text_lens
=
[
0
]
*
num_choices
*
num_prompts
previous_text_lens
=
[
0
]
*
num_choices
*
num_prompts
...
@@ -346,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -346,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing):
exclude_unset
=
False
,
exclude_none
=
True
))
exclude_unset
=
False
,
exclude_none
=
True
))
yield
f
"data:
{
final_usage_data
}
\n\n
"
yield
f
"data:
{
final_usage_data
}
\n\n
"
# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens
=
sum
(
num_prompt_tokens
)
total_completion_tokens
=
sum
(
previous_num_tokens
)
request_metadata
.
final_usage_info
=
UsageInfo
(
prompt_tokens
=
total_prompt_tokens
,
completion_tokens
=
total_completion_tokens
,
total_tokens
=
total_prompt_tokens
+
total_completion_tokens
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
data
=
self
.
create_streaming_error_response
(
str
(
e
))
data
=
self
.
create_streaming_error_response
(
str
(
e
))
...
@@ -360,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -360,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time
:
int
,
created_time
:
int
,
model_name
:
str
,
model_name
:
str
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
CompletionResponse
:
)
->
CompletionResponse
:
choices
:
List
[
CompletionResponseChoice
]
=
[]
choices
:
List
[
CompletionResponseChoice
]
=
[]
num_prompt_tokens
=
0
num_prompt_tokens
=
0
...
@@ -433,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -433,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
)
)
request_metadata
.
final_usage_info
=
usage
return
CompletionResponse
(
return
CompletionResponse
(
id
=
request_id
,
id
=
request_id
,
created
=
created_time
,
created
=
created_time
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment