# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script demonstrates how to use the vLLM API server to perform audio
transcription with the `openai/whisper-large-v3` model.

Before running this script, you must start the vLLM server with the following command:

    vllm serve openai/whisper-large-v3

Requirements:
- vLLM with audio support
- openai Python SDK
- requests (used for the raw-HTTP streaming path)

The script performs:
1. Synchronous transcription using the OpenAI-compatible API.
2. Streaming transcription, either through the async OpenAI SDK (for models
   whose name contains "openai") or via a raw HTTP request to the vLLM server.
"""

21
import argparse
22
23
import asyncio

24
from openai import AsyncOpenAI, OpenAI
25
26
27
28

from vllm.assets.audio import AudioAsset


29
def sync_openai(audio_path: str, client: OpenAI, model: str):
    """
    Perform synchronous transcription using OpenAI-compatible API.

    Args:
        audio_path: Path to the audio file to upload.
        client: An ``OpenAI`` client configured to talk to the vLLM server.
        model: Name of the served model to transcribe with.

    The transcribed text is printed to stdout; nothing is returned.
    """
    with open(audio_path, "rb") as f:
        transcription = client.audio.transcriptions.create(
            file=f,
            model=model,
            language="en",
            response_format="json",
            temperature=0.0,
            # Additional sampling params not provided by OpenAI API; they are
            # sent to the server through the SDK's `extra_body` option.
            extra_body=dict(
                seed=4419,
                repetition_penalty=1.3,
            ),
        )
    # The upload has completed, so the file can be closed before printing.
    print("transcription result [sync]:", transcription.text)


49
async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: str):
    """
    Perform asynchronous, streaming transcription using OpenAI-compatible API.

    Args:
        audio_path: Path to the audio file to upload.
        client: An ``AsyncOpenAI`` client configured to talk to the vLLM server.
        model: Name of the served model to transcribe with.

    Text deltas are printed to stdout as they arrive, followed by a final
    newline once the stream is exhausted.
    """
    print("\ntranscription result [stream]:", end=" ")
    with open(audio_path, "rb") as f:
        transcription = await client.audio.transcriptions.create(
            file=f,
            model=model,
            language="en",
            response_format="json",
            temperature=0.0,
            # Additional sampling params not provided by OpenAI API.
            extra_body=dict(
                seed=420,
                top_p=0.6,
            ),
            stream=True,
        )
        async for chunk in transcription:
            # NOTE(review): each choice is accessed as a dict (`.get`),
            # presumably because `choices` is not a typed field of the SDK's
            # stream event and is left as raw JSON — confirm against the
            # openai SDK version in use.
            choices = getattr(chunk, "choices", None)
            if not choices:
                continue
            delta = choices[0].get("delta") or {}
            content = delta.get("content")
            # Skip chunks that carry no text (e.g. one that only reports a
            # finish reason) instead of printing the literal "None".
            if content:
                print(content, end="", flush=True)

    print()  # Final newline after stream ends


def stream_api_response(audio_path: str, model: str, openai_api_base: str):
    """
    Perform streaming transcription using raw HTTP requests to the vLLM API server.

    The audio file is uploaded as multipart form data to
    ``{openai_api_base}/audio/transcriptions`` with ``stream=true``. The
    server-sent-event (SSE) response is then parsed line by line, and text
    deltas are printed to stdout as they arrive.

    Args:
        audio_path: Path to the audio file to upload.
        model: Name of the served model to transcribe with.
        openai_api_base: Base URL of the server, e.g. "http://localhost:8000/v1".

    Raises:
        requests.HTTPError: If the server answers with a non-2xx status.
    """
    import json
    import os

    import requests

    api_url = f"{openai_api_base}/audio/transcriptions"
    headers = {"User-Agent": "Transcription-Client"}
    # Form fields sent alongside the file; kept under a distinct name so it is
    # not confused with the parsed stream events below.
    form_data = {
        "stream": "true",
        "model": model,
        "language": "en",
        "response_format": "json",
    }
    with open(audio_path, "rb") as f:
        files = {"file": (os.path.basename(audio_path), f)}

        print("\ntranscription result [stream]:", end=" ")
        # Use the response as a context manager so the streamed connection is
        # released even when we stop reading early (the `break`s below).
        with requests.post(
            api_url, headers=headers, files=files, data=form_data, stream=True
        ) as response:
            # Fail loudly on HTTP errors instead of trying to parse an error
            # body as SSE events.
            response.raise_for_status()
            # Default line splitting also strips a trailing "\r" (CRLF).
            for line in response.iter_lines(chunk_size=8192):
                # SSE separates events with blank lines; only `data:` lines
                # carry a payload.
                if not line or not line.startswith(b"data:"):
                    continue
                payload = line[len(b"data:") :].strip()
                # OpenAI-style streams may end with a literal `[DONE]`
                # sentinel, which is not JSON.
                if payload == b"[DONE]":
                    break
                event = json.loads(payload.decode("utf-8"))
                choices = event.get("choices")
                if not choices:
                    continue
                choice = choices[0]
                content = (choice.get("delta") or {}).get("content")
                if content:
                    print(content, end="", flush=True)

                finish_reason = choice.get("finish_reason")
                if finish_reason is not None:
                    print(f"\n[Stream finished reason: {finish_reason}]")
                    break


def main(args):
    """
    Run the example: one synchronous and one streaming transcription.

    Args:
        args: Parsed command-line arguments. If ``args.audio_path`` is set, it
            is used for both requests instead of the bundled sample clips.
    """
    # Only resolve the bundled sample assets (which may require fetching them)
    # when the user did not supply a file of their own.
    if args.audio_path:
        sync_audio = stream_audio = args.audio_path
    else:
        sync_audio = str(AudioAsset("mary_had_lamb").get_local_path())
        stream_audio = str(AudioAsset("winning_call").get_local_path())

    # Modify OpenAI's API key and API base to use vLLM's API server.
    openai_api_key = "EMPTY"
    openai_api_base = "http://localhost:8000/v1"
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    # Use the first model reported by the server's /models endpoint.
    model = client.models.list().data[0].id
    print(f"Using model: {model}")

    # Run the synchronous function
    sync_openai(sync_audio, client, model)

    # Run the streaming transcription: models whose name contains "openai"
    # go through the async OpenAI SDK, all others through raw HTTP.
    if "openai" in model:
        async_client = AsyncOpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
        asyncio.run(stream_openai_response(stream_audio, async_client, model))
    else:
        stream_api_response(stream_audio, model, openai_api_base)


153
if __name__ == "__main__":
    # Set up the command-line interface and hand the parsed args to main().
    parser = argparse.ArgumentParser(
        description="OpenAI Transcription Client using vLLM API Server"
    )
    parser.add_argument(
        "--audio_path",
        type=str,
        default=None,
        help="The path to the audio file to transcribe.",
    )
    args = parser.parse_args()
    main(args)