Unverified Commit adcf682f authored by Ekagra Ranjan's avatar Ekagra Ranjan Committed by GitHub
Browse files

[Audio] Improve Audio Inference Scripts (offline/online) (#29279)


Signed-off-by: default avatarEkagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
parent 21de6d4b
...@@ -495,17 +495,17 @@ def main(args): ...@@ -495,17 +495,17 @@ def main(args):
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
) )
def get_input(start, end):
mm_data = req_data.multi_modal_data mm_data = req_data.multi_modal_data
if not mm_data: if not mm_data:
mm_data = {} mm_data = {}
if audio_count > 0: if end - start > 0:
mm_data = { mm_data = {
"audio": [ "audio": [
asset.audio_and_sample_rate for asset in audio_assets[:audio_count] asset.audio_and_sample_rate for asset in audio_assets[start:end]
] ]
} }
assert args.num_prompts > 0
inputs = {"multi_modal_data": mm_data} inputs = {"multi_modal_data": mm_data}
if req_data.prompt: if req_data.prompt:
...@@ -513,9 +513,22 @@ def main(args): ...@@ -513,9 +513,22 @@ def main(args):
else: else:
inputs["prompt_token_ids"] = req_data.prompt_token_ids inputs["prompt_token_ids"] = req_data.prompt_token_ids
if args.num_prompts > 1: return inputs
# Batch inference # Batch inference
assert args.num_prompts > 0
if audio_count != 1:
inputs = get_input(0, audio_count)
inputs = [inputs] * args.num_prompts inputs = [inputs] * args.num_prompts
else:
# For single audio input, we need to vary the audio input
# to avoid deduplication in vLLM engine.
inputs = []
for i in range(args.num_prompts):
start = i % len(audio_assets)
inp = get_input(start, start + 1)
inputs.append(inp)
# Add LoRA request if applicable # Add LoRA request if applicable
lora_request = ( lora_request = (
req_data.lora_requests * args.num_prompts if req_data.lora_requests else None req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
......
...@@ -18,6 +18,7 @@ The script performs: ...@@ -18,6 +18,7 @@ The script performs:
2. Streaming transcription using raw HTTP request to the vLLM server. 2. Streaming transcription using raw HTTP request to the vLLM server.
""" """
import argparse
import asyncio import asyncio
from openai import AsyncOpenAI, OpenAI from openai import AsyncOpenAI, OpenAI
...@@ -25,14 +26,14 @@ from openai import AsyncOpenAI, OpenAI ...@@ -25,14 +26,14 @@ from openai import AsyncOpenAI, OpenAI
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
def sync_openai(audio_path: str, client: OpenAI): def sync_openai(audio_path: str, client: OpenAI, model: str):
""" """
Perform synchronous transcription using OpenAI-compatible API. Perform synchronous transcription using OpenAI-compatible API.
""" """
with open(audio_path, "rb") as f: with open(audio_path, "rb") as f:
transcription = client.audio.transcriptions.create( transcription = client.audio.transcriptions.create(
file=f, file=f,
model="openai/whisper-large-v3", model=model,
language="en", language="en",
response_format="json", response_format="json",
temperature=0.0, temperature=0.0,
...@@ -42,18 +43,18 @@ def sync_openai(audio_path: str, client: OpenAI): ...@@ -42,18 +43,18 @@ def sync_openai(audio_path: str, client: OpenAI):
repetition_penalty=1.3, repetition_penalty=1.3,
), ),
) )
print("transcription result:", transcription.text) print("transcription result [sync]:", transcription.text)
async def stream_openai_response(audio_path: str, client: AsyncOpenAI): async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: str):
""" """
Perform asynchronous transcription using OpenAI-compatible API. Perform asynchronous transcription using OpenAI-compatible API.
""" """
print("\ntranscription result:", end=" ") print("\ntranscription result [stream]:", end=" ")
with open(audio_path, "rb") as f: with open(audio_path, "rb") as f:
transcription = await client.audio.transcriptions.create( transcription = await client.audio.transcriptions.create(
file=f, file=f,
model="openai/whisper-large-v3", model=model,
language="en", language="en",
response_format="json", response_format="json",
temperature=0.0, temperature=0.0,
...@@ -72,7 +73,47 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI): ...@@ -72,7 +73,47 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI):
print() # Final newline after stream ends print() # Final newline after stream ends
def main(): def stream_api_response(audio_path: str, model: str, openai_api_base: str):
"""
Perform streaming transcription using raw HTTP requests to the vLLM API server.
"""
import json
import os
import requests
api_url = f"{openai_api_base}/audio/transcriptions"
headers = {"User-Agent": "Transcription-Client"}
with open(audio_path, "rb") as f:
files = {"file": (os.path.basename(audio_path), f)}
data = {
"stream": "true",
"model": model,
"language": "en",
"response_format": "json",
}
print("\ntranscription result [stream]:", end=" ")
response = requests.post(
api_url, headers=headers, files=files, data=data, stream=True
)
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\n"
):
if chunk:
data = chunk[len("data: ") :]
data = json.loads(data.decode("utf-8"))
data = data["choices"][0]
delta = data["delta"]["content"]
print(delta, end="", flush=True)
finish_reason = data.get("finish_reason")
if finish_reason is not None:
print(f"\n[Stream finished reason: {finish_reason}]")
break
def main(args):
mary_had_lamb = str(AudioAsset("mary_had_lamb").get_local_path()) mary_had_lamb = str(AudioAsset("mary_had_lamb").get_local_path())
winning_call = str(AudioAsset("winning_call").get_local_path()) winning_call = str(AudioAsset("winning_call").get_local_path())
...@@ -84,14 +125,41 @@ def main(): ...@@ -84,14 +125,41 @@ def main():
base_url=openai_api_base, base_url=openai_api_base,
) )
sync_openai(mary_had_lamb, client) model = client.models.list().data[0].id
print(f"Using model: {model}")
# Run the synchronous function
sync_openai(args.audio_path if args.audio_path else mary_had_lamb, client, model)
# Run the asynchronous function # Run the asynchronous function
if "openai" in model:
client = AsyncOpenAI( client = AsyncOpenAI(
api_key=openai_api_key, api_key=openai_api_key,
base_url=openai_api_base, base_url=openai_api_base,
) )
asyncio.run(stream_openai_response(winning_call, client)) asyncio.run(
stream_openai_response(
args.audio_path if args.audio_path else winning_call, client, model
)
)
else:
stream_api_response(
args.audio_path if args.audio_path else winning_call,
model,
openai_api_base,
)
if __name__ == "__main__": if __name__ == "__main__":
main() # setup argparser
parser = argparse.ArgumentParser(
description="OpenAI Transcription Client using vLLM API Server"
)
parser.add_argument(
"--audio_path",
type=str,
default=None,
help="The path to the audio file to transcribe.",
)
args = parser.parse_args()
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment