# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import AsyncGenerator
4
from typing import Optional, Union
5
6
7
8
9
10

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
11
from vllm.entrypoints.openai.protocol import (
12
    ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
13
    TranscriptionResponse, TranscriptionResponseStreamChoice,
14
15
    TranscriptionStreamResponse, TranslationRequest, TranslationResponse,
    TranslationResponseStreamChoice, TranslationStreamResponse)
16
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
17
from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText
18
19
20
21
22
23
from vllm.logger import init_logger
from vllm.outputs import RequestOutput

# Module-level logger, created through vLLM's `init_logger` with this
# module's name.
logger = init_logger(__name__)


24
25
class OpenAIServingTranscription(OpenAISpeechToText):
    """Serving handler for audio transcription requests.

    Specializes `OpenAISpeechToText` by constructing it with
    `task_type="transcribe"` and by handing the transcription response and
    stream-chunk classes to the inherited helper methods.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Initialize the base class with the task type fixed to "transcribe".

        Every argument is forwarded unchanged to
        `OpenAISpeechToText.__init__`; their handling is defined there.

        Args:
            engine_client: Engine client forwarded to the base class.
            model_config: Model configuration forwarded to the base class.
            models: Serving-models object forwarded to the base class.
            request_logger: Optional request logger forwarded to the base
                class; keyword-only and required (may be `None`).
            return_tokens_as_token_ids: Flag forwarded to the base class.
            log_error_stack: Flag forwarded to the base class.
        """
        super().__init__(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
            task_type="transcribe",
        )

    async def create_transcription(
        self,
        audio_data: bytes,
        request: TranscriptionRequest,
        raw_request: Request,
    ) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
               ErrorResponse]:
        """Transcription API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/audio/createTranscription
        for the API specification. This API mimics the OpenAI transcription API.

        The work is delegated to the inherited `_create_speech_to_text`,
        which is given `TranscriptionResponse` as the response class and
        `transcription_stream_generator` as the streaming method.

        Args:
            audio_data: Raw bytes of the audio to transcribe.
            request: The transcription request.
            raw_request: The underlying FastAPI request.

        Returns:
            Whatever `_create_speech_to_text` returns; per the annotation,
            a `TranscriptionResponse`, an async generator of `str` chunks,
            or an `ErrorResponse`.
        """
        outcome = await self._create_speech_to_text(
            audio_data=audio_data,
            request=request,
            raw_request=raw_request,
            response_class=TranscriptionResponse,
            stream_generator_method=self.transcription_stream_generator,
        )
        return outcome

    async def transcription_stream_generator(
        self,
        request: TranscriptionRequest,
        result_generator: list[AsyncGenerator[RequestOutput, None]],
        request_id: str,
        request_metadata: RequestResponseMetadata,
        audio_duration_s: float,
    ) -> AsyncGenerator[str, None]:
        """Yield transcription stream chunks from the shared base generator.

        Calls the inherited `_speech_to_text_stream_generator` with the
        transcription-specific stream classes and the chunk object type
        `"transcription.chunk"`, then re-yields each produced chunk
        unchanged.

        Args:
            request: The transcription request being streamed.
            result_generator: List of async generators of `RequestOutput`,
                passed on as `list_result_generator`.
            request_id: Identifier forwarded to the base generator.
            request_metadata: Metadata object forwarded to the base
                generator.
            audio_duration_s: Audio duration value forwarded to the base
                generator.

        Yields:
            Each chunk emitted by `_speech_to_text_stream_generator`.
        """
        chunk_stream = self._speech_to_text_stream_generator(
            request=request,
            list_result_generator=result_generator,
            request_id=request_id,
            request_metadata=request_metadata,
            audio_duration_s=audio_duration_s,
            chunk_object_type="transcription.chunk",
            response_stream_choice_class=TranscriptionResponseStreamChoice,
            stream_response_class=TranscriptionStreamResponse,
        )
        async for piece in chunk_stream:
            yield piece


class OpenAIServingTranslation(OpenAISpeechToText):
    """Serving handler for audio translation requests.

    Specializes `OpenAISpeechToText` by constructing it with
    `task_type="translate"` and by handing the translation response and
    stream-chunk classes to the inherited helper methods.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Initialize the base class with the task type fixed to "translate".

        Every argument is forwarded unchanged to
        `OpenAISpeechToText.__init__`; their handling is defined there.

        Args:
            engine_client: Engine client forwarded to the base class.
            model_config: Model configuration forwarded to the base class.
            models: Serving-models object forwarded to the base class.
            request_logger: Optional request logger forwarded to the base
                class; keyword-only and required (may be `None`).
            return_tokens_as_token_ids: Flag forwarded to the base class.
            log_error_stack: Flag forwarded to the base class.
        """
        super().__init__(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
            task_type="translate",
        )

    async def create_translation(
        self,
        audio_data: bytes,
        request: TranslationRequest,
        raw_request: Request,
    ) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]:
        """Translation API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/audio/createTranslation
        for the API specification. This API mimics the OpenAI translation API.

        The work is delegated to the inherited `_create_speech_to_text`,
        which is given `TranslationResponse` as the response class and
        `translation_stream_generator` as the streaming method.

        Args:
            audio_data: Raw bytes of the audio to translate.
            request: The translation request.
            raw_request: The underlying FastAPI request.

        Returns:
            Whatever `_create_speech_to_text` returns; per the annotation,
            a `TranslationResponse`, an async generator of `str` chunks,
            or an `ErrorResponse`.
        """
        outcome = await self._create_speech_to_text(
            audio_data=audio_data,
            request=request,
            raw_request=raw_request,
            response_class=TranslationResponse,
            stream_generator_method=self.translation_stream_generator,
        )
        return outcome

    async def translation_stream_generator(
        self,
        request: TranslationRequest,
        result_generator: list[AsyncGenerator[RequestOutput, None]],
        request_id: str,
        request_metadata: RequestResponseMetadata,
        audio_duration_s: float,
    ) -> AsyncGenerator[str, None]:
        """Yield translation stream chunks from the shared base generator.

        Calls the inherited `_speech_to_text_stream_generator` with the
        translation-specific stream classes and the chunk object type
        `"translation.chunk"`, then re-yields each produced chunk
        unchanged.

        Args:
            request: The translation request being streamed.
            result_generator: List of async generators of `RequestOutput`,
                passed on as `list_result_generator`.
            request_id: Identifier forwarded to the base generator.
            request_metadata: Metadata object forwarded to the base
                generator.
            audio_duration_s: Audio duration value forwarded to the base
                generator.

        Yields:
            Each chunk emitted by `_speech_to_text_stream_generator`.
        """
        chunk_stream = self._speech_to_text_stream_generator(
            request=request,
            list_result_generator=result_generator,
            request_id=request_id,
            request_metadata=request_metadata,
            audio_duration_s=audio_duration_s,
            chunk_object_type="translation.chunk",
            response_stream_choice_class=TranslationResponseStreamChoice,
            stream_response_class=TranslationStreamResponse,
        )
        async for piece in chunk_stream:
            yield piece