Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fa82b938
Unverified
Commit
fa82b938
authored
Mar 06, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Mar 06, 2025
Browse files
[Frontend][Docs] Transcription API streaming (#13301)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
69ff99fd
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
297 additions
and
26 deletions
+297
-26
docs/source/serving/openai_compatible_server.md
docs/source/serving/openai_compatible_server.md
+4
-0
examples/online_serving/openai_transcription_client.py
examples/online_serving/openai_transcription_client.py
+51
-8
tests/entrypoints/openai/test_transcription_validation.py
tests/entrypoints/openai/test_transcription_validation.py
+72
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+39
-1
vllm/entrypoints/openai/serving_transcription.py
vllm/entrypoints/openai/serving_transcription.py
+131
-17
No files found.
docs/source/serving/openai_compatible_server.md
View file @
fa82b938
...
@@ -379,6 +379,10 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
...
@@ -379,6 +379,10 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
:::{note}
To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
:::
<!-- TODO: api enforced limits + uploading audios -->
<!-- TODO: api enforced limits + uploading audios -->
Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
...
...
examples/online_serving/openai_transcription_client.py
View file @
fa82b938
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
json
import
httpx
from
openai
import
OpenAI
from
openai
import
OpenAI
from
vllm.assets.audio
import
AudioAsset
from
vllm.assets.audio
import
AudioAsset
...
@@ -13,11 +17,50 @@ client = OpenAI(
...
@@ -13,11 +17,50 @@ client = OpenAI(
api_key
=
openai_api_key
,
api_key
=
openai_api_key
,
base_url
=
openai_api_base
,
base_url
=
openai_api_base
,
)
)
with
open
(
str
(
mary_had_lamb
),
"rb"
)
as
f
:
transcription
=
client
.
audio
.
transcriptions
.
create
(
file
=
f
,
def
sync_openai
():
model
=
"openai/whisper-large-v3"
,
with
open
(
str
(
mary_had_lamb
),
"rb"
)
as
f
:
language
=
"en"
,
transcription
=
client
.
audio
.
transcriptions
.
create
(
response_format
=
"text"
,
file
=
f
,
temperature
=
0.0
)
model
=
"openai/whisper-small"
,
print
(
"transcription result:"
,
transcription
)
language
=
"en"
,
response_format
=
"json"
,
temperature
=
0.0
)
print
(
"transcription result:"
,
transcription
.
text
)
sync_openai
()
# OpenAI Transcription API client does not support streaming.
async def stream_openai_response():
    """Request a streamed transcription and print the text as it arrives.

    The OpenAI client does not expose streaming for transcriptions, so the
    request is sent with ``httpx`` as a multipart form POST to
    ``<openai_api_base>/audio/transcriptions`` with ``stream`` set to True.
    The response is read line by line; each non-empty line is expected to be
    a JSON object, optionally prefixed with ``'data: '``, and the stream ends
    when the ``[DONE]`` sentinel is received.
    """
    sse_prefix = 'data: '
    data = {
        "language": "en",
        'stream': True,
        "model": "openai/whisper-large-v3",
    }
    url = openai_api_base + "/audio/transcriptions"
    print("transcription result:", end=' ')
    async with httpx.AsyncClient() as client:
        with open(str(winning_call), "rb") as f:
            async with client.stream('POST',
                                     url,
                                     files={'file': f},
                                     data=data) as response:
                async for line in response.aiter_lines():
                    # Blank lines only separate events; nothing to parse.
                    if not line:
                        continue
                    # Each line is a JSON object prefixed with 'data: '
                    if line.startswith(sse_prefix):
                        line = line[len(sse_prefix):]
                    # Last chunk, stream ends
                    if line.strip() == '[DONE]':
                        break
                    # Parse the JSON response
                    chunk = json.loads(line)
                    # A chunk may carry no choices (e.g. a usage-only chunk);
                    # skip it instead of failing on `choices[0]`.
                    choices = chunk.get('choices') or []
                    if not choices:
                        continue
                    # Extract and print the content, ignoring empty deltas so
                    # that a literal "None" is never printed.
                    content = choices[0].get('delta', {}).get('content')
                    if content:
                        # Flush so partial text is visible while streaming.
                        print(content, end='', flush=True)
    # Terminate the output line once the stream is complete.
    print()
# Run the asynchronous function
asyncio
.
run
(
stream_openai_response
())
tests/entrypoints/openai/test_transcription_validation.py
View file @
fa82b938
...
@@ -3,12 +3,14 @@
...
@@ -3,12 +3,14 @@
# imports for guided decoding tests
# imports for guided decoding tests
import
io
import
io
import
json
import
json
from
unittest.mock
import
patch
import
librosa
import
librosa
import
numpy
as
np
import
numpy
as
np
import
openai
import
openai
import
pytest
import
pytest
import
soundfile
as
sf
import
soundfile
as
sf
from
openai._base_client
import
AsyncAPIClient
from
vllm.assets.audio
import
AudioAsset
from
vllm.assets.audio
import
AudioAsset
...
@@ -120,3 +122,73 @@ async def test_completion_endpoints():
...
@@ -120,3 +122,73 @@ async def test_completion_endpoints():
res
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
"Hello"
)
res
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
"Hello"
)
assert
res
.
code
==
400
assert
res
.
code
==
400
assert
res
.
message
==
"The model does not support Completions API"
assert
res
.
message
==
"The model does not support Completions API"
@pytest.mark.asyncio
async def test_streaming_response(winning_call):
    """Streamed chunks, once concatenated, must equal the non-streamed text."""
    model_name = "openai/whisper-small"
    server_args = ["--enforce-eager"]
    pieces = []
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Reference transcription obtained without streaming.
        baseline_client = remote_server.get_async_client()
        res_no_stream = await baseline_client.audio.transcriptions.create(
            model=model_name,
            file=winning_call,
            response_format="json",
            language="en",
            temperature=0.0)

        # Unfortunately this only works when the openai client is patched
        # to use streaming mode, not exposed in the transcription api.
        original_post = AsyncAPIClient.post

        async def post_with_stream(*args, **kwargs):
            # Force the underlying HTTP call into streaming mode.
            kwargs['stream'] = True
            return await original_post(*args, **kwargs)

        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
            streaming_client = remote_server.get_async_client()
            res = await streaming_client.audio.transcriptions.create(
                model=model_name,
                file=winning_call,
                language="en",
                temperature=0.0,
                extra_body=dict(stream=True))
            # Reconstruct from chunks and validate
            async for chunk in res:
                # Each chunk carries one delta with a text fragment.
                pieces.append(chunk.choices[0]['delta']['content'])

        transcription = "".join(pieces)
        assert transcription == res_no_stream.text
@pytest.mark.asyncio
async def test_stream_options(winning_call):
    """With both usage options on, expect a final choice-less chunk and a
    `usage` attribute on every chunk that carries a choice."""
    model_name = "openai/whisper-small"
    server_args = ["--enforce-eager"]
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        original_post = AsyncAPIClient.post

        async def post_with_stream(*args, **kwargs):
            # Force the underlying HTTP call into streaming mode.
            kwargs['stream'] = True
            return await original_post(*args, **kwargs)

        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
            client = remote_server.get_async_client()
            stream_extras = dict(stream=True,
                                 stream_include_usage=True,
                                 stream_continuous_usage_stats=True)
            res = await client.audio.transcriptions.create(
                model=model_name,
                file=winning_call,
                language="en",
                temperature=0.0,
                extra_body=stream_extras)
            saw_final_usage = False
            usage_on_every_chunk = True
            async for chunk in res:
                if chunk.choices:
                    # Regular delta chunk: it must expose `usage`.
                    usage_on_every_chunk = (usage_on_every_chunk
                                            and hasattr(chunk, 'usage'))
                else:
                    # final usage sent
                    saw_final_usage = True
            assert saw_final_usage and usage_on_every_chunk
vllm/entrypoints/openai/protocol.py
View file @
fa82b938
...
@@ -1285,6 +1285,21 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
...
@@ -1285,6 +1285,21 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    """One choice entry inside a streamed transcription chunk."""
    # Incremental message payload for this chunk.
    delta: DeltaMessage
    # Optional; defaults to None when not provided.
    finish_reason: Optional[str] = None
    # Optional int or str; defaults to None when not provided.
    stop_reason: Optional[Union[int, str]] = None
class TranscriptionStreamResponse(OpenAIBaseModel):
    """A single chunk of a streamed transcription response."""
    # Defaults to a fresh "trsc-<uuid>" identifier when not supplied.
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    # Constant discriminator for this chunk type.
    object: Literal["transcription.chunk"] = "transcription.chunk"
    # Defaults to the current time as integer seconds (time.time()).
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    # Optional usage statistics; None unless explicitly set.
    usage: Optional[UsageInfo] = Field(default=None)
class
BatchRequestInput
(
OpenAIBaseModel
):
class
BatchRequestInput
(
OpenAIBaseModel
):
"""
"""
The per-line object of the batch input file.
The per-line object of the batch input file.
...
@@ -1510,6 +1525,15 @@ class TranscriptionRequest(OpenAIBaseModel):
...
@@ -1510,6 +1525,15 @@ class TranscriptionRequest(OpenAIBaseModel):
timestamps incurs additional latency.
timestamps incurs additional latency.
"""
"""
stream
:
Optional
[
bool
]
=
False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
Completion endpoint.
"""
# Flattened stream option to simplify form data.
stream_include_usage
:
Optional
[
bool
]
=
False
stream_continuous_usage_stats
:
Optional
[
bool
]
=
False
# Default sampling parameters for transcription requests.
# Default sampling parameters for transcription requests.
_DEFAULT_SAMPLING_PARAMS
:
dict
=
{
_DEFAULT_SAMPLING_PARAMS
:
dict
=
{
"temperature"
:
0
,
"temperature"
:
0
,
...
@@ -1530,7 +1554,21 @@ class TranscriptionRequest(OpenAIBaseModel):
...
@@ -1530,7 +1554,21 @@ class TranscriptionRequest(OpenAIBaseModel):
"temperature"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
"temperature"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
return
SamplingParams
.
from_optional
(
temperature
=
temperature
,
return
SamplingParams
.
from_optional
(
temperature
=
temperature
,
max_tokens
=
max_tokens
)
max_tokens
=
max_tokens
,
output_kind
=
RequestOutputKind
.
DELTA
if
self
.
stream
\
else
RequestOutputKind
.
FINAL_ONLY
)
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
    """Reject stream options that are set while streaming is disabled.

    Runs before field validation, so ``data`` is the raw input.

    Args:
        data: Raw input for the model. Only dict input is inspected.

    Returns:
        ``data`` unchanged.

    Raises:
        ValueError: If ``stream_include_usage`` or
            ``stream_continuous_usage_stats`` is truthy while ``stream``
            is falsy or absent.
    """
    # A "before" validator is not guaranteed to receive a dict; leave any
    # other input to regular validation instead of failing on `.get`.
    if not isinstance(data, dict):
        return data

    stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
    stream = data.get("stream", False)
    if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
        raise ValueError(
            "Stream options can only be defined when `stream=True`.")

    return data
# Transcription response objects
# Transcription response objects
...
...
vllm/entrypoints/openai/serving_transcription.py
View file @
fa82b938
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
asyncio
import
io
import
io
import
time
from
collections.abc
import
AsyncGenerator
from
collections.abc
import
AsyncGenerator
from
typing
import
Optional
,
Union
,
cast
from
math
import
ceil
from
typing
import
Final
,
Optional
,
Union
,
cast
from
fastapi
import
Request
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
from
vllm.entrypoints.openai.protocol
import
(
RequestResponseMetadata
,
DeltaMessage
,
ErrorResponse
,
RequestResponseMetadata
,
TranscriptionRequest
,
TranscriptionRequest
,
TranscriptionResponse
,
TranscriptionResponseStreamChoice
,
TranscriptionResponse
,
TranscriptionStreamResponse
,
UsageInfo
)
TranscriptionResponseVerbose
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.data
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.utils
import
PlaceholderModule
from
vllm.utils
import
PlaceholderModule
try
:
try
:
...
@@ -140,8 +142,6 @@ ISO639_1_OTHER_LANGS = {
...
@@ -140,8 +142,6 @@ ISO639_1_OTHER_LANGS = {
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB
=
25
MAX_AUDIO_CLIP_FILESIZE_MB
=
25
# TODO get from processor.feature_extractor.chunk_length
MAX_AUDIO_CLIP_DURATION_S
=
30
class
OpenAIServingTranscription
(
OpenAIServing
):
class
OpenAIServingTranscription
(
OpenAIServing
):
...
@@ -163,6 +163,11 @@ class OpenAIServingTranscription(OpenAIServing):
...
@@ -163,6 +163,11 @@ class OpenAIServingTranscription(OpenAIServing):
self
.
default_sampling_params
=
(
self
.
default_sampling_params
=
(
self
.
model_config
.
get_diff_sampling_param
())
self
.
model_config
.
get_diff_sampling_param
())
processor
=
cached_get_processor
(
model_config
.
model
)
self
.
max_audio_clip_s
=
processor
.
feature_extractor
.
chunk_length
self
.
model_sr
=
processor
.
feature_extractor
.
sampling_rate
self
.
hop_length
=
processor
.
feature_extractor
.
hop_length
if
self
.
default_sampling_params
:
if
self
.
default_sampling_params
:
logger
.
info
(
logger
.
info
(
"Overwriting default completion sampling param with: %s"
,
"Overwriting default completion sampling param with: %s"
,
...
@@ -172,7 +177,7 @@ class OpenAIServingTranscription(OpenAIServing):
...
@@ -172,7 +177,7 @@ class OpenAIServingTranscription(OpenAIServing):
self
,
self
,
request
:
TranscriptionRequest
,
request
:
TranscriptionRequest
,
audio_data
:
bytes
,
audio_data
:
bytes
,
)
->
PromptType
:
)
->
tuple
[
PromptType
,
float
]
:
# Validate request
# Validate request
# TODO language should be optional and can be guessed.
# TODO language should be optional and can be guessed.
# For now we default to en. See
# For now we default to en. See
...
@@ -198,9 +203,11 @@ class OpenAIServingTranscription(OpenAIServing):
...
@@ -198,9 +203,11 @@ class OpenAIServingTranscription(OpenAIServing):
with
io
.
BytesIO
(
audio_data
)
as
bytes_
:
with
io
.
BytesIO
(
audio_data
)
as
bytes_
:
y
,
sr
=
librosa
.
load
(
bytes_
)
y
,
sr
=
librosa
.
load
(
bytes_
)
if
librosa
.
get_duration
(
y
=
y
,
sr
=
sr
)
>
MAX_AUDIO_CLIP_DURATION_S
:
duration
=
librosa
.
get_duration
(
y
=
y
,
sr
=
sr
)
if
duration
>
self
.
max_audio_clip_s
:
raise
ValueError
(
raise
ValueError
(
f
"Maximum clip duration (
{
MAX_AUDIO_CLIP_DURATION_S
}
s) "
f
"Maximum clip duration (
{
self
.
max_audio_clip_s
}
s) "
"exceeded."
)
"exceeded."
)
prompt
=
{
prompt
=
{
...
@@ -213,13 +220,13 @@ class OpenAIServingTranscription(OpenAIServing):
...
@@ -213,13 +220,13 @@ class OpenAIServingTranscription(OpenAIServing):
"decoder_prompt"
:
"decoder_prompt"
:
f
"<|startoftranscript|>
{
lang_token
}
<|transcribe|><|notimestamps|>
{
request
.
prompt
}
"
f
"<|startoftranscript|>
{
lang_token
}
<|transcribe|><|notimestamps|>
{
request
.
prompt
}
"
}
}
return
cast
(
PromptType
,
prompt
)
return
cast
(
PromptType
,
prompt
)
,
duration
# TODO (varun) : Make verbose response work !
# TODO (varun) : Make verbose response work !
async
def
create_transcription
(
async
def
create_transcription
(
self
,
audio_data
:
bytes
,
request
:
TranscriptionRequest
,
self
,
audio_data
:
bytes
,
request
:
TranscriptionRequest
,
raw_request
:
Request
raw_request
:
Request
)
->
Union
[
TranscriptionResponse
,
TranscriptionResponseVerbose
,
)
->
Union
[
TranscriptionResponse
,
AsyncGenerator
[
str
,
None
]
,
ErrorResponse
]:
ErrorResponse
]:
"""Transcription API similar to OpenAI's API.
"""Transcription API similar to OpenAI's API.
...
@@ -240,8 +247,7 @@ class OpenAIServingTranscription(OpenAIServing):
...
@@ -240,8 +247,7 @@ class OpenAIServingTranscription(OpenAIServing):
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"Currently only support response_format `text` or `json`"
)
"Currently only support response_format `text` or `json`"
)
# TODO cmpl->transcription?
request_id
=
f
"trsc-
{
self
.
_base_request_id
(
raw_request
)
}
"
request_id
=
f
"cmpl-
{
self
.
_base_request_id
(
raw_request
)
}
"
request_metadata
=
RequestResponseMetadata
(
request_id
=
request_id
)
request_metadata
=
RequestResponseMetadata
(
request_id
=
request_id
)
if
raw_request
:
if
raw_request
:
...
@@ -261,7 +267,7 @@ class OpenAIServingTranscription(OpenAIServing):
...
@@ -261,7 +267,7 @@ class OpenAIServingTranscription(OpenAIServing):
"Currently do not support PromptAdapter for Transcription."
"Currently do not support PromptAdapter for Transcription."
)
)
prompt
=
await
self
.
_preprocess_transcription
(
prompt
,
duration_s
=
await
self
.
_preprocess_transcription
(
request
=
request
,
request
=
request
,
audio_data
=
audio_data
,
audio_data
=
audio_data
,
)
)
...
@@ -293,7 +299,12 @@ class OpenAIServingTranscription(OpenAIServing):
...
@@ -293,7 +299,12 @@ class OpenAIServingTranscription(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
# TODO(rob): figure out a way to pipe streaming in.
if
request
.
stream
:
return
self
.
transcription_stream_generator
(
request
,
result_generator
,
request_id
,
request_metadata
,
duration_s
)
# Non-streaming response.
# Non-streaming response.
try
:
try
:
assert
result_generator
is
not
None
assert
result_generator
is
not
None
...
@@ -305,3 +316,106 @@ class OpenAIServingTranscription(OpenAIServing):
...
@@ -305,3 +316,106 @@ class OpenAIServingTranscription(OpenAIServing):
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
async def transcription_stream_generator(
        self, request: TranscriptionRequest,
        result_generator: AsyncGenerator[RequestOutput, None],
        request_id: str, request_metadata: RequestResponseMetadata,
        audio_duration_s: float) -> AsyncGenerator[str, None]:
    """Yield a transcription as a stream of ``data: ...`` events.

    Args:
        request: The transcription request; its ``model`` and the two
            ``stream_*`` usage options are read here.
        result_generator: Async iterator of engine outputs to convert into
            chunks.
        request_id: Identifier placed on every emitted chunk.
        request_metadata: Receives the aggregate usage in
            ``final_usage_info`` once the stream completes normally.
        audio_duration_s: Duration of the input audio in seconds, used to
            estimate the encoder share of the prompt token count.

    Yields:
        One ``data: <json>\\n\\n`` string per engine output, optionally a
        final usage-only chunk, an error payload if anything raises, and
        always a terminating ``data: [DONE]\\n\\n``.
    """
    created_time = int(time.time())
    model_name = request.model
    chunk_object_type: Final = "transcription.chunk"

    completion_tokens = 0
    num_prompt_tokens = 0

    # Continuous usage is only honoured when usage reporting itself is on.
    include_usage = bool(request.stream_include_usage)
    include_continuous_usage = include_usage and bool(
        request.stream_continuous_usage_stats)

    def current_usage() -> UsageInfo:
        # Snapshot of the counters at the time of the call.
        return UsageInfo(prompt_tokens=num_prompt_tokens,
                         completion_tokens=completion_tokens,
                         total_tokens=num_prompt_tokens + completion_tokens)

    try:
        async for res in result_generator:
            # On first result.
            if res.prompt_token_ids is not None:
                # Do not account the 4-tokens `<|startoftranscript|>..`
                # Could be negative when language token is not specified.
                num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
                # NOTE(NickLucche) user can't pass encoder prompts directly
                # at least not to Whisper. One indicator of the encoder
                # amount of processing is the log-mel spectrogram length.
                num_prompt_tokens += ceil(audio_duration_s * self.model_sr /
                                          self.hop_length)

            # We need to do it here, because if there are exceptions in
            # the result_generator, it needs to be sent as the FIRST
            # response (by the try...except).

            # Just one output (n=1) supported. Raise explicitly rather than
            # assert so the check survives `-O` and the streamed error has
            # a meaningful message.
            if len(res.outputs) != 1:
                raise ValueError(
                    "Transcription streaming supports exactly one output, "
                    f"got {len(res.outputs)}.")
            output = res.outputs[0]

            delta_message = DeltaMessage(content=output.text)
            completion_tokens += len(output.token_ids)

            if output.finish_reason is None:
                # Still generating, send delta update.
                choice_data = TranscriptionResponseStreamChoice(
                    delta=delta_message)
            else:
                # Model is finished generating.
                choice_data = TranscriptionResponseStreamChoice(
                    delta=delta_message,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason)

            chunk = TranscriptionStreamResponse(id=request_id,
                                                object=chunk_object_type,
                                                created=created_time,
                                                choices=[choice_data],
                                                model=model_name)

            # handle usage stats if requested & if continuous
            if include_continuous_usage:
                chunk.usage = current_usage()

            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"

        # Once the final token is handled, if stream_options.include_usage
        # is sent, send the usage.
        if include_usage:
            final_usage_chunk = TranscriptionStreamResponse(
                id=request_id,
                object=chunk_object_type,
                created=created_time,
                choices=[],
                model=model_name,
                usage=current_usage())
            final_usage_data = final_usage_chunk.model_dump_json(
                exclude_unset=True, exclude_none=True)
            yield f"data: {final_usage_data}\n\n"

        # report to FastAPI middleware aggregate usage across all choices
        request_metadata.final_usage_info = current_usage()

    except Exception as e:
        # TODO: Use a vllm-specific Validation Error
        logger.exception("Error in transcription stream generator.")
        data = self.create_streaming_error_response(str(e))
        yield f"data: {data}\n\n"
    # Send the final done message after all response.n are finished
    yield "data: [DONE]\n\n"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment