"vllm/vscode:/vscode.git/clone" did not exist on "fcf2e3d7fcc9898b7a1b26bacea22753ab76f3a6"
speech_to_text.py 28.7 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
7
from collections.abc import AsyncGenerator, Callable
8
from functools import cached_property
9
from typing import Literal, TypeAlias, TypeVar, cast
10
11
12

import numpy as np
from fastapi import Request
13
from transformers import PreTrainedTokenizerBase
14

15
import vllm.envs as envs
16
17
18
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
19
20
21
22
23
    DeltaMessage,
    ErrorResponse,
    RequestResponseMetadata,
    TranscriptionResponse,
    TranscriptionResponseStreamChoice,
24
25
    TranscriptionResponseVerbose,
    TranscriptionSegment,
26
27
28
    TranscriptionStreamResponse,
    TranslationResponse,
    TranslationResponseStreamChoice,
29
30
    TranslationResponseVerbose,
    TranslationSegment,
31
32
    TranslationStreamResponse,
    UsageInfo,
33
    VLLMValidationError,
34
35
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest
36
37
38
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
39
from vllm.model_executor.models import SupportsTranscription, supports_transcription
40
from vllm.outputs import RequestOutput
41
from vllm.tokenizers import get_tokenizer
42
from vllm.utils.import_utils import PlaceholderModule
43
44
45
46
47
48

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

49
# Plain (non-verbose) response models for the two supported tasks.
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
# Verbose response models (the ones carrying per-segment details).
SpeechToTextResponseVerbose: TypeAlias = (
    TranscriptionResponseVerbose | TranslationResponseVerbose
)
# Segment models used inside the verbose responses.
SpeechToTextSegment: TypeAlias = TranscriptionSegment | TranslationSegment

# Type variables bounded by the aliases above.
T = TypeVar("T", bound=SpeechToTextResponse)
V = TypeVar("V", bound=SpeechToTextResponseVerbose)
S = TypeVar("S", bound=SpeechToTextSegment)

# Any response model a non-streaming speech-to-text call can return.
ResponseType: TypeAlias = SpeechToTextResponse | SpeechToTextResponseVerbose

logger = init_logger(__name__)


class OpenAISpeechToText(OpenAIServing):
69
    """Base class for speech-to-text operations like transcription and
    translation."""

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        task_type: Literal["transcribe", "translate"] = "transcribe",
        log_error_stack: bool = False,
        enable_force_include_usage: bool = False,
    ):
        """Configure the speech-to-text serving layer and run warmups.

        Args:
            engine_client: Forwarded to the parent initializer.
            models: Forwarded to the parent initializer.
            request_logger: Forwarded to the parent initializer.
            return_tokens_as_token_ids: Forwarded to the parent initializer.
            task_type: Either ``"transcribe"`` or ``"translate"``; stored on
                the instance and used to obtain the model's ASR config.
            log_error_stack: Forwarded to the parent initializer.
            enable_force_include_usage: Stored on the instance; consulted by
                the streaming generator when deciding to emit usage.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        model_config = self.model_config

        # Sampling defaults and simple per-instance settings.
        self.default_sampling_params = model_config.get_diff_sampling_param()
        self.task_type = task_type
        self.enable_force_include_usage = enable_force_include_usage
        self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB

        # Model-provided speech-to-text configuration for this task.
        model_cls = self.model_cls
        self.asr_config = model_cls.get_speech_to_text_config(model_config, task_type)

        # A tokenizer is only loaded for models that can emit segment
        # timestamps; it is used when building verbose segments.
        if model_cls.supports_segment_timestamp:
            loaded_tokenizer = get_tokenizer(
                tokenizer_name=model_config.tokenizer,
                tokenizer_mode=model_config.tokenizer_mode,
            )
            self.tokenizer = cast(PreTrainedTokenizerBase, loaded_tokenizer)

        if self.default_sampling_params:
            logger.info(
                "Overwriting default completion sampling param with: %s",
                self.default_sampling_params,
            )

        # Warm up audio preprocessing to avoid first-request latency
        self._warmup_audio_preprocessing()
        # Warm up input processor with dummy audio
        self._warmup_input_processor()

    def _warmup_audio_preprocessing(self) -> None:
        """Warm up audio processing libraries to avoid first-request latency.

        The first call to librosa functions (load, get_duration, mel-spectrogram)
        triggers JIT compilation and library initialization which can take ~7s.
        This method warms up these operations during server initialization.

        The warmup is best-effort: it is skipped when librosa is not installed
        or the model does not support transcription, and any exception raised
        while warming up is logged (with traceback) and swallowed.
        """
        # Skip warmup if librosa is not installed (optional dependency)
        if isinstance(librosa, PlaceholderModule):
            return

        # Skip warmup if model doesn't support transcription
        if not supports_transcription(self.model_cls):
            return

        try:
            warmup_start = time.perf_counter()
            logger.info("Warming up audio preprocessing libraries...")

            # Create a minimal dummy audio (1 second of silence at target sample rate)
            dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)

            # Exercise librosa on the dummy data so that its lazy
            # initialization happens now instead of on the first request.
            _ = librosa.get_duration(y=dummy_audio, sr=self.asr_config.sample_rate)

            # Warm up mel-spectrogram computation with model-specific parameters
            from vllm.transformers_utils.processor import (
                cached_processor_from_config,
            )

            processor = cached_processor_from_config(self.model_config)
            feature_extractor = None
            if hasattr(processor, "feature_extractor"):
                feature_extractor = processor.feature_extractor
            elif hasattr(processor, "audio_processor"):
                # For models like GraniteSpeech that use audio_processor
                audio_proc = processor.audio_processor
                if hasattr(audio_proc, "feature_extractor"):
                    feature_extractor = audio_proc.feature_extractor
                # If audio_processor doesn't have feature_extractor,
                # skip mel-spectrogram warmup for these models

            if feature_extractor is not None:
                _ = librosa.feature.melspectrogram(
                    y=dummy_audio,
                    sr=self.asr_config.sample_rate,
                    n_mels=getattr(feature_extractor, "n_mels", 128),
                    n_fft=getattr(feature_extractor, "n_fft", 400),
                    hop_length=getattr(feature_extractor, "hop_length", 160),
                )

            warmup_elapsed = time.perf_counter() - warmup_start
            logger.info("Audio preprocessing warmup completed in %.2fs", warmup_elapsed)
        except Exception:
            # Don't fail initialization if warmup fails - log exception and continue.
            # logger.exception already attaches the traceback, so the message
            # carries no placeholder (a bare "%s" without args is logged verbatim).
            logger.exception(
                "Audio preprocessing warmup failed (non-fatal). "
                "First request may experience higher latency."
            )

    def _warmup_input_processor(self) -> None:
        """Warm up input processor with dummy audio to avoid first-request latency.

        The first call to input_processor.process_inputs() with multimodal audio
        triggers multimodal processing initialization which can take ~2.5s.
        This method processes a dummy audio request to warm up the pipeline.

        The warmup is best-effort: it is skipped for models that do not
        support transcription or lack ``get_generation_prompt``, and any
        exception raised is logged (with traceback) and swallowed.
        """
        # Skip warmup if model doesn't support transcription
        if not supports_transcription(self.model_cls):
            return

        # Only warm up if model supports transcription methods
        if not hasattr(self.model_cls, "get_generation_prompt"):
            return

        try:
            from vllm.sampling_params import SamplingParams

            warmup_start = time.perf_counter()
            logger.info("Warming up multimodal input processor...")

            # Create minimal dummy audio (1 second of silence)
            dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)

            # Use the same method that _preprocess_speech_to_text uses
            # to create the prompt
            dummy_prompt = self.model_cls.get_generation_prompt(
                audio=dummy_audio,
                stt_config=self.asr_config,
                model_config=self.model_config,
                language="en",
                task_type=self.task_type,
                request_prompt="",
                to_language=None,
            )

            # Create minimal sampling params
            dummy_params = SamplingParams(
                max_tokens=1,
                temperature=0.0,
                skip_clone=True,  # Internal warmup, safe to skip clone
            )

            # Process the dummy input through the input processor
            # This will trigger all the multimodal processing initialization
            _ = self.input_processor.process_inputs(
                request_id="warmup",
                prompt=dummy_prompt,
                params=dummy_params,
            )

            warmup_elapsed = time.perf_counter() - warmup_start
            logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
        except Exception:
            # Don't fail initialization if warmup fails - log exception and continue.
            # logger.exception already attaches the traceback, so the message
            # carries no placeholder (a bare "%s" without args is logged verbatim).
            logger.exception(
                "Input processor warmup failed (non-fatal). "
                "First request may experience higher latency."
            )

242
    @cached_property
    def model_cls(self) -> type[SupportsTranscription]:
        """Model class resolved from ``self.model_config`` (computed once).

        The result is cast to ``type[SupportsTranscription]`` for the type
        checker; no runtime check is performed here.
        """
        # Imported at call time rather than at module import.
        from vllm.model_executor.model_loader import get_model_cls

        return cast(type[SupportsTranscription], get_model_cls(self.model_config))
248

249
250
251
252
253
254
    async def _preprocess_speech_to_text(
        self,
        request: SpeechToTextRequest,
        audio_data: bytes,
    ) -> tuple[list[PromptType], float]:
        """Validate the request, decode the audio and build model prompts.

        Args:
            request: The incoming speech-to-text request.
            audio_data: Raw bytes of the uploaded audio file.

        Returns:
            A ``(prompts, duration)`` pair: one prompt per audio chunk and
            the duration of the decoded audio as reported by librosa.

        Raises:
            VLLMValidationError: If the file exceeds the configured size
                limit, or (for ``verbose_json``) the generated prompt is not
                a dict with a string ``decoder_prompt``.
        """
        model_cls = self.model_cls
        asr_config = self.asr_config

        # Validate request
        language = model_cls.validate_language(request.language)
        # to_language is only validated when the request provides one.
        to_language = None
        if request.to_language:
            to_language = model_cls.validate_language(request.to_language)

        filesize_mb = len(audio_data) / 1024**2
        if filesize_mb > self.max_audio_filesize_mb:
            raise VLLMValidationError(
                "Maximum file size exceeded",
                parameter="audio_filesize_mb",
                value=filesize_mb,
            )

        with io.BytesIO(audio_data) as audio_buffer:
            # NOTE resample to model SR here for efficiency. This is also a
            # pre-requisite for chunking, as it assumes Whisper SR.
            waveform, sample_rate = librosa.load(audio_buffer, sr=asr_config.sample_rate)

        duration = librosa.get_duration(y=waveform, sr=sample_rate)

        needs_chunking = (
            asr_config.allow_audio_chunking and duration > asr_config.max_audio_clip_s
        )
        if needs_chunking:
            chunks = self._split_audio(waveform, int(sample_rate))
        else:
            chunks = [waveform]

        want_timestamps = request.response_format == "verbose_json"
        prompts = []
        for chunk in chunks:
            # The model has control over the construction, as long as it
            # returns a valid PromptType.
            prompt = model_cls.get_generation_prompt(
                audio=chunk,
                stt_config=asr_config,
                model_config=self.model_config,
                language=language,
                task_type=self.task_type,
                request_prompt=request.prompt,
                to_language=to_language,
            )
            if want_timestamps:
                if not isinstance(prompt, dict):
                    raise VLLMValidationError(
                        "Expected prompt to be a dict",
                        parameter="prompt",
                        value=type(prompt).__name__,
                    )
                prompt_dict = cast(dict, prompt)
                decoder_prompt = prompt_dict.get("decoder_prompt")
                if not isinstance(decoder_prompt, str):
                    raise VLLMValidationError(
                        "Expected decoder_prompt to be str",
                        parameter="decoder_prompt",
                        value=type(decoder_prompt).__name__,
                    )
                # Swap the "no timestamps" marker for an initial timestamp
                # token so that the output carries timestamp tokens.
                prompt_dict["decoder_prompt"] = decoder_prompt.replace(
                    "<|notimestamps|>", "<|0.00|>"
                )
            prompts.append(prompt)

        return prompts, duration

315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
    def _get_verbose_segments(
        self,
        tokens: tuple,
        request: SpeechToTextRequest,
        segment_class: type[SpeechToTextSegment],
        start_time: float = 0,
    ) -> list[SpeechToTextSegment]:
        """
        Convert tokens to verbose segments.

        This method expects the model to produce
        timestamps as tokens (similar to Whisper).
        If the tokens do not include timestamp information,
        the segments may not be generated correctly.

        Note: Fields like avg_logprob, compression_ratio,
        and no_speech_prob are not supported
        in this implementation and will be None. See docs for details.
        """
        BASE_OFFSET = 0.02
        init_token = self.tokenizer.encode("<|0.00|>", add_special_tokens=False)[0]
        if tokens[-1] == self.tokenizer.eos_token_id:
            tokens = tokens[:-1]

        tokens_with_start = (init_token,) + tokens
        segments: list[SpeechToTextSegment] = []
        last_timestamp_start = 0

        if tokens_with_start[-2] < init_token and tokens_with_start[-1] >= init_token:
            tokens_with_start = tokens_with_start + (tokens_with_start[-1],)
        for idx, token in enumerate(tokens_with_start):
            # Timestamp tokens (e.g., <|0.00|>) are assumed to be sorted.
            # If the ordering is violated, this slicing may produce incorrect results.
            if (
                token >= init_token
                and idx != 0
                and tokens_with_start[idx - 1] >= init_token
            ):
                sliced_timestamp_tokens = tokens_with_start[last_timestamp_start:idx]
                start_timestamp = sliced_timestamp_tokens[0] - init_token
                end_timestamp = sliced_timestamp_tokens[-1] - init_token

                casting_segment = cast(
                    SpeechToTextSegment,
                    segment_class(
                        id=len(segments),
                        seek=start_time,
                        start=start_time + BASE_OFFSET * start_timestamp,
                        end=start_time + BASE_OFFSET * end_timestamp,
                        temperature=request.temperature,
                        text=self.tokenizer.decode(sliced_timestamp_tokens[1:-1]),
                        tokens=sliced_timestamp_tokens[1:-1],
                    ),
                )
                segments.append(casting_segment)
                last_timestamp_start = idx
        return segments

373
374
375
376
377
    async def _create_speech_to_text(
        self,
        audio_data: bytes,
        request: SpeechToTextRequest,
        raw_request: Request,
        response_class: type[T | V],
        stream_generator_method: Callable[..., AsyncGenerator[str, None]],
    ) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
        """Base method for speech-to-text operations like transcription and
        translation.

        Args:
            audio_data: Raw bytes of the uploaded audio file.
            request: The speech-to-text request.
            raw_request: The HTTP request; when truthy, the request metadata
                is attached to ``raw_request.state``.
            response_class: Part of the interface; not read by this method.
            stream_generator_method: Called with ``(request, generators,
                request_id, request_metadata, duration_s)`` when
                ``request.stream`` is set; its result is returned as-is.

        Returns:
            A transcription/translation response (plain or verbose), the
            stream generator for streaming requests, or an error response
            for unsupported options and ``ValueError`` failures.

        Raises:
            The engine's ``dead_error`` when the engine client has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        if request.response_format not in ["text", "json", "verbose_json"]:
            return self.create_error_response(
                "Currently only support response_format "
                "`text`, `json` or `verbose_json`"
            )

        is_verbose = request.response_format == "verbose_json"

        if is_verbose and not self.model_cls.supports_segment_timestamp:
            return self.create_error_response(
                f"Currently do not support verbose_json for {request.model}"
            )

        if is_verbose and request.stream:
            return self.create_error_response(
                "verbose_json format doesn't support streaming case"
            )

        request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)

            prompts, duration_s = await self._preprocess_speech_to_text(
                request=request,
                audio_data=audio_data,
            )

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(e)

        list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
        try:
            # Unlike most decoder-only models, whisper generation length is not
            # constrained by the size of the input audio, which is mapped to a
            # fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
            # generated by respecting the extra completion tokens arg.
            if request.max_completion_tokens is None:
                default_max_tokens = self.model_config.max_model_len
            else:
                default_max_tokens = min(
                    self.model_config.max_model_len, request.max_completion_tokens
                )
            sampling_params = request.to_sampling_params(
                default_max_tokens, self.default_sampling_params
            )

            self._log_inputs(
                request_id,
                # It will not display special tokens like <|startoftranscript|>
                request.prompt,
                params=sampling_params,
                lora_request=lora_request,
            )

            # One generation request per audio chunk.
            list_result_generator = [
                self.engine_client.generate(
                    prompt,
                    sampling_params,
                    f"{request_id}_{i}",
                    lora_request=lora_request,
                )
                for i, prompt in enumerate(prompts)
            ]
        except ValueError as e:
            return self.create_error_response(e)

        if request.stream:
            return stream_generator_method(
                request, list_result_generator, request_id, request_metadata, duration_s
            )

        # Non-streaming response.
        total_segments = []
        text_parts = []
        try:
            assert list_result_generator is not None
            segments_types: dict[str, type[SpeechToTextSegment]] = {
                "transcribe": TranscriptionSegment,
                "translate": TranslationSegment,
            }
            segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]

            chunk_size_in_s = self.asr_config.max_audio_clip_s
            if chunk_size_in_s is None:
                assert len(list_result_generator) == 1, (
                    "`max_audio_clip_s` is set to None, audio cannot be chunked"
                )
            for idx, result_generator in enumerate(list_result_generator):
                # Segment times are offset by the nominal start of the chunk.
                start_time = (
                    float(idx * chunk_size_in_s) if chunk_size_in_s is not None else 0.0
                )
                async for op in result_generator:
                    if is_verbose:
                        segments: list[SpeechToTextSegment] = (
                            self._get_verbose_segments(
                                tokens=tuple(op.outputs[0].token_ids),
                                segment_class=segment_class,
                                request=request,
                                start_time=start_time,
                            )
                        )

                        total_segments.extend(segments)
                        text_parts.extend([seg.text for seg in segments])
                    else:
                        text_parts.append(op.outputs[0].text)
            text = "".join(text_parts)

            final_response: ResponseType
            if self.task_type == "transcribe":
                # add usage in TranscriptionResponse.
                usage = {
                    "type": "duration",
                    # rounded up as per openAI specs
                    "seconds": int(math.ceil(duration_s)),
                }
                if not is_verbose:
                    final_response = cast(
                        T, TranscriptionResponse(text=text, usage=usage)
                    )
                else:
                    final_response = cast(
                        V,
                        TranscriptionResponseVerbose(
                            text=text,
                            language=request.language,
                            duration=str(duration_s),
                            segments=total_segments,
                        ),
                    )
            else:
                # no usage in response for translation task
                if not is_verbose:
                    final_response = cast(T, TranslationResponse(text=text))
                else:
                    final_response = cast(
                        V,
                        TranslationResponseVerbose(
                            text=text,
                            language=request.language,
                            duration=str(duration_s),
                            segments=total_segments,
                        ),
                    )
            return final_response
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(e)
546
547
548
549
550
551
552
553
554

    async def _speech_to_text_stream_generator(
        self,
        request: SpeechToTextRequest,
        list_result_generator: list[AsyncGenerator[RequestOutput, None]],
        request_id: str,
        request_metadata: RequestResponseMetadata,
        audio_duration_s: float,
        chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
        response_stream_choice_class: type[TranscriptionResponseStreamChoice]
        | type[TranslationResponseStreamChoice],
        stream_response_class: type[TranscriptionStreamResponse]
        | type[TranslationStreamResponse],
    ) -> AsyncGenerator[str, None]:
        """Stream generation results as ``data: ...`` event strings.

        The result generators are consumed one after another. Every engine
        output becomes one ``data: <json>\\n\\n`` string; an optional usage
        chunk follows, and ``data: [DONE]\\n\\n`` is always the last item.
        Any exception is logged and emitted as a streaming error response
        instead of propagating.

        Args:
            request: The originating request (model name and usage options).
            list_result_generator: One engine output generator per chunk.
            request_id: Identifier placed in every emitted chunk.
            request_metadata: Receives the final usage information.
            audio_duration_s: Audio duration passed to the model class when
                asking for the number of audio tokens.
            chunk_object_type: Value of the ``object`` field of each chunk.
            response_stream_choice_class: Class used to build each choice.
            stream_response_class: Class used to build each chunk.
        """
        created_time = int(time.time())
        model_name = request.model

        completion_tokens = 0
        num_prompt_tokens = 0

        include_usage = self.enable_force_include_usage or request.stream_include_usage
        include_continuous_usage = False
        if include_usage and request.stream_continuous_usage_stats:
            include_continuous_usage = request.stream_continuous_usage_stats

        def usage_snapshot() -> UsageInfo:
            # Usage computed from the counters as they stand right now.
            return UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=num_prompt_tokens + completion_tokens,
            )

        try:
            for result_generator in list_result_generator:
                async for res in result_generator:
                    # On first result.
                    if res.prompt_token_ids is not None:
                        num_prompt_tokens = len(res.prompt_token_ids)
                        audio_tokens = self.model_cls.get_num_audio_tokens(
                            audio_duration_s, self.asr_config, self.model_config
                        )
                        if audio_tokens:
                            num_prompt_tokens += audio_tokens

                    # We need to do it here, because if there are exceptions in
                    # the result_generator, it needs to be sent as the FIRST
                    # response (by the try...catch).

                    # Just one output (n=1) supported.
                    assert len(res.outputs) == 1
                    output = res.outputs[0]

                    delta_message = DeltaMessage(content=output.text)
                    completion_tokens += len(output.token_ids)

                    # Finish/stop reasons are only passed once the model is
                    # done, so they stay unset on intermediate deltas.
                    choice_kwargs = {"delta": delta_message}
                    if output.finish_reason is not None:
                        choice_kwargs["finish_reason"] = output.finish_reason
                        choice_kwargs["stop_reason"] = output.stop_reason
                    choice_data = response_stream_choice_class(**choice_kwargs)

                    chunk = stream_response_class(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name,
                    )

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        chunk.usage = usage_snapshot()

                    payload = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {payload}\n\n"

            # Once the final token is handled, if stream_options.include_usage
            # is sent, send the usage.
            if include_usage:
                final_usage_chunk = stream_response_class(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=usage_snapshot(),
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = usage_snapshot()

        except Exception as e:
            logger.exception("Error in %s stream generator.", self.task_type)
            error_payload = self.create_streaming_error_response(e)
            yield f"data: {error_payload}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

661
662
663
    def _split_audio(
        self, audio_data: np.ndarray, sample_rate: int
    ) -> list[np.ndarray]:
664
665
666
667
        assert self.asr_config.max_audio_clip_s is not None, (
            f"{self.asr_config.max_audio_clip_s=} cannot be None to"
            " split audio into chunks."
        )
668
669
        chunk_size = sample_rate * self.asr_config.max_audio_clip_s
        overlap_size = sample_rate * self.asr_config.overlap_chunk_second
670
671
672
673
674
675
676
677
678
679
680
        chunks = []
        i = 0
        while i < audio_data.shape[-1]:
            if i + chunk_size >= audio_data.shape[-1]:
                # handle last chunk
                chunks.append(audio_data[..., i:])
                break

            # Find the best split point in the overlap region
            search_start = i + chunk_size - overlap_size
            search_end = min(i + chunk_size, audio_data.shape[-1])
681
            split_point = self._find_split_point(audio_data, search_start, search_end)
682
683
684
685
686
687

            # Extract chunk up to the split point
            chunks.append(audio_data[..., i:split_point])
            i = split_point
        return chunks

688
689
    def _find_split_point(self, wav: np.ndarray, start_idx: int, end_idx: int) -> int:
        """Find the best point to split audio by
690
691
692
693
694
695
696
697
698
699
700
701
702
        looking for silence or low amplitude.
        Args:
            wav: Audio tensor [1, T]
            start_idx: Start index of search region
            end_idx: End index of search region
        Returns:
            Index of best splitting point
        """
        segment = wav[start_idx:end_idx]

        # Calculate RMS energy in small windows
        min_energy = math.inf
        quietest_idx = 0
703
704
705
        min_energy_window = self.asr_config.min_energy_split_window_size
        assert min_energy_window is not None
        for i in range(0, len(segment) - min_energy_window, min_energy_window):
706
707
            window = segment[i : i + min_energy_window]
            energy = (window**2).mean() ** 0.5
708
709
710
711
            if energy < min_energy:
                quietest_idx = i + start_idx
                min_energy = energy
        return quietest_idx