"tests/kernels/attention/test_attention.py" did not exist on "03ffd0a02251e10c1aa14fca8cb0ab1e4e40b886"
speech_to_text.py 27.9 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
7
from collections.abc import AsyncGenerator, Callable
8
from functools import cached_property
9
from typing import Literal, TypeAlias, TypeVar, cast
10
11
12

import numpy as np
from fastapi import Request
13
from transformers import PreTrainedTokenizerBase
14

15
import vllm.envs as envs
16
17
18
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
19
20
21
22
23
    DeltaMessage,
    ErrorResponse,
    RequestResponseMetadata,
    TranscriptionResponse,
    TranscriptionResponseStreamChoice,
24
25
    TranscriptionResponseVerbose,
    TranscriptionSegment,
26
27
28
    TranscriptionStreamResponse,
    TranslationResponse,
    TranslationResponseStreamChoice,
29
30
    TranslationResponseVerbose,
    TranslationSegment,
31
32
33
34
    TranslationStreamResponse,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest
35
36
37
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
38
from vllm.model_executor.models import SupportsTranscription, supports_transcription
39
from vllm.outputs import RequestOutput
40
from vllm.tokenizers import get_tokenizer
41
from vllm.utils.import_utils import PlaceholderModule
42
43
44
45
46
47

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

48
# Non-verbose response models for the two speech-to-text tasks.
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
# Verbose response models (built with a `segments` list further down).
SpeechToTextResponseVerbose: TypeAlias = (
    TranscriptionResponseVerbose | TranslationResponseVerbose
)
# Per-segment models used inside the verbose responses.
SpeechToTextSegment: TypeAlias = TranscriptionSegment | TranslationSegment

# Type variables bound to the aliases above, used in generic signatures.
T = TypeVar("T", bound=SpeechToTextResponse)
V = TypeVar("V", bound=SpeechToTextResponseVerbose)
S = TypeVar("S", bound=SpeechToTextSegment)

# Any response object a non-streaming speech-to-text call can return.
ResponseType: TypeAlias = (
    TranscriptionResponse
    | TranslationResponse
    | TranscriptionResponseVerbose
    | TranslationResponseVerbose
)

logger = init_logger(__name__)


class OpenAISpeechToText(OpenAIServing):
    """Base class for speech-to-text operations like transcription and
    translation."""

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        task_type: Literal["transcribe", "translate"] = "transcribe",
        log_error_stack: bool = False,
        enable_force_include_usage: bool = False,
    ):
        """Set up the speech-to-text serving state and warm up preprocessing.

        Args:
            engine_client: Forwarded to ``OpenAIServing.__init__``.
            models: Forwarded to ``OpenAIServing.__init__``.
            request_logger: Forwarded to ``OpenAIServing.__init__``.
            return_tokens_as_token_ids: Forwarded to
                ``OpenAIServing.__init__``.
            task_type: Which task this instance serves; stored on the
                instance and passed to the model's
                ``get_speech_to_text_config``.
            log_error_stack: Forwarded to ``OpenAIServing.__init__``.
            enable_force_include_usage: When true, streaming responses
                include usage even if the request did not ask for it.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.task_type = task_type

        # Model-specific speech-to-text settings (sample rate, chunking, ...).
        self.asr_config = self.model_cls.get_speech_to_text_config(
            self.model_config, task_type
        )

        self.enable_force_include_usage = enable_force_include_usage

        # Upper bound (in MiB) on the size of an uploaded audio payload.
        self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB

        # The tokenizer is only loaded for models that can emit segment
        # timestamps; it is used to decode tokens into verbose segments.
        if self.model_cls.supports_segment_timestamp:
            self.tokenizer = cast(
                PreTrainedTokenizerBase,
                get_tokenizer(
                    tokenizer_name=self.model_config.tokenizer,
                    tokenizer_mode=self.model_config.tokenizer_mode,
                ),
            )

        if self.default_sampling_params:
            logger.info(
                "Overwriting default completion sampling param with: %s",
                self.default_sampling_params,
            )

        # Warm up audio preprocessing to avoid first-request latency
        self._warmup_audio_preprocessing()
        # Warm up input processor with dummy audio
        self._warmup_input_processor()

    def _warmup_audio_preprocessing(self) -> None:
        """Warm up audio processing libraries to avoid first-request latency.

        The first call to librosa functions (load, get_duration, mel-spectrogram)
        triggers JIT compilation and library initialization which can take ~7s.
        This method warms up these operations during server initialization.

        The warmup is best-effort: it is skipped when librosa is missing or
        the model does not support transcription, and any exception raised
        while warming up is logged and swallowed so initialization continues.
        """
        # Skip warmup if librosa is not installed (optional dependency)
        if isinstance(librosa, PlaceholderModule):
            return

        # Skip warmup if model doesn't support transcription
        if not supports_transcription(self.model_cls):
            return

        try:
            warmup_start = time.perf_counter()
            logger.info("Warming up audio preprocessing libraries...")

            # Create a minimal dummy audio (1 second of silence at target sample rate)
            dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)

            # Exercise librosa on the dummy data so that its lazy
            # initialization happens now rather than on the first request.
            _ = librosa.get_duration(y=dummy_audio, sr=self.asr_config.sample_rate)

            # Warm up mel-spectrogram computation with model-specific parameters
            from vllm.transformers_utils.processor import (
                cached_processor_from_config,
            )

            processor = cached_processor_from_config(self.model_config)
            feature_extractor = None
            if hasattr(processor, "feature_extractor"):
                feature_extractor = processor.feature_extractor
            elif hasattr(processor, "audio_processor"):
                # For models like GraniteSpeech that use audio_processor
                audio_proc = processor.audio_processor
                if hasattr(audio_proc, "feature_extractor"):
                    feature_extractor = audio_proc.feature_extractor
                # If audio_processor doesn't have feature_extractor,
                # skip mel-spectrogram warmup for these models

            if feature_extractor is not None:
                _ = librosa.feature.melspectrogram(
                    y=dummy_audio,
                    sr=self.asr_config.sample_rate,
                    n_mels=getattr(feature_extractor, "n_mels", 128),
                    n_fft=getattr(feature_extractor, "n_fft", 400),
                    hop_length=getattr(feature_extractor, "hop_length", 160),
                )

            warmup_elapsed = time.perf_counter() - warmup_start
            logger.info("Audio preprocessing warmup completed in %.2fs", warmup_elapsed)
        except Exception:
            # Don't fail initialization if warmup fails - log exception and continue.
            # logger.exception already attaches the traceback, so the message
            # carries no %-placeholder (one without an argument would be
            # emitted literally).
            logger.exception(
                "Audio preprocessing warmup failed (non-fatal). "
                "First request may experience higher latency."
            )

    def _warmup_input_processor(self) -> None:
        """Warm up input processor with dummy audio to avoid first-request latency.

        The first call to input_processor.process_inputs() with multimodal audio
        triggers multimodal processing initialization which can take ~2.5s.
        This method processes a dummy audio request to warm up the pipeline.

        The warmup is best-effort: it is skipped for models without
        transcription support, and any exception raised while warming up is
        logged and swallowed so initialization continues.
        """
        # Skip warmup if model doesn't support transcription
        if not supports_transcription(self.model_cls):
            return

        # Only warm up if model supports transcription methods
        if not hasattr(self.model_cls, "get_generation_prompt"):
            return

        try:
            from vllm.sampling_params import SamplingParams

            warmup_start = time.perf_counter()
            logger.info("Warming up multimodal input processor...")

            # Create minimal dummy audio (1 second of silence)
            dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)

            # Use the same method that _preprocess_speech_to_text uses
            # to create the prompt
            dummy_prompt = self.model_cls.get_generation_prompt(
                audio=dummy_audio,
                stt_config=self.asr_config,
                model_config=self.model_config,
                language="en",
                task_type=self.task_type,
                request_prompt="",
                to_language=None,
            )

            # Create minimal sampling params
            dummy_params = SamplingParams(
                max_tokens=1,
                temperature=0.0,
            )

            # Process the dummy input through the input processor
            # This will trigger all the multimodal processing initialization
            _ = self.input_processor.process_inputs(
                request_id="warmup",
                prompt=dummy_prompt,
                params=dummy_params,
            )

            warmup_elapsed = time.perf_counter() - warmup_start
            logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
        except Exception:
            # Don't fail initialization if warmup fails - log exception and continue.
            # logger.exception already attaches the traceback, so the message
            # carries no %-placeholder (one without an argument would be
            # emitted literally).
            logger.exception(
                "Input processor warmup failed (non-fatal). "
                "First request may experience higher latency."
            )

240
    @cached_property
    def model_cls(self) -> type[SupportsTranscription]:
        """Return the model class for ``self.model_config``.

        The lookup runs once per instance; ``cached_property`` stores the
        result for later accesses. The class is cast to
        ``SupportsTranscription`` for type checkers only; no runtime check
        is performed here.
        """
        # Imported lazily inside the property rather than at module import.
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsTranscription], model_cls)
246

247
248
249
250
251
252
    async def _preprocess_speech_to_text(
        self,
        request: SpeechToTextRequest,
        audio_data: bytes,
    ) -> tuple[list[PromptType], float]:
        """Decode the uploaded audio and build one prompt per audio chunk.

        Args:
            request: The incoming speech-to-text request.
            audio_data: Raw bytes of the uploaded audio file.

        Returns:
            A tuple ``(prompts, duration)`` where ``prompts`` holds one
            model prompt per audio chunk and ``duration`` is the length of
            the decoded audio as reported by ``librosa.get_duration``.

        Raises:
            ValueError: If the payload exceeds ``max_audio_filesize_mb``, or
                if ``verbose_json`` is requested and the model's prompt is
                not a dict with a string ``decoder_prompt``.
        """
        # Validate request
        language = self.model_cls.validate_language(request.language)
        # Only validate to_language when the request actually provides one.
        to_language = (
            self.model_cls.validate_language(request.to_language)
            if request.to_language
            else None
        )

        if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
            raise ValueError("Maximum file size exceeded.")

        with io.BytesIO(audio_data) as bytes_:
            # NOTE resample to model SR here for efficiency. This is also a
            # pre-requisite for chunking, as it assumes Whisper SR.
            y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)

        duration = librosa.get_duration(y=y, sr=sr)
        # Only split when the model allows it and the clip is too long.
        do_split_audio = (
            self.asr_config.allow_audio_chunking
            and duration > self.asr_config.max_audio_clip_s
        )
        chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))

        prompts: list[PromptType] = []
        for chunk in chunks:
            # The model has control over the construction, as long as it
            # returns a valid PromptType.
            prompt = self.model_cls.get_generation_prompt(
                audio=chunk,
                stt_config=self.asr_config,
                model_config=self.model_config,
                language=language,
                task_type=self.task_type,
                request_prompt=request.prompt,
                to_language=to_language,
            )
            if request.response_format == "verbose_json":
                # Verbose output needs timestamp tokens: swap the
                # "no timestamps" marker for the initial timestamp token.
                if not isinstance(prompt, dict):
                    raise ValueError(
                        f"Expected prompt to be a dict, got {type(prompt)}"
                    )
                prompt_dict = cast(dict, prompt)
                decoder_prompt = prompt.get("decoder_prompt")
                if not isinstance(decoder_prompt, str):
                    raise ValueError(
                        f"Expected decoder_prompt to be str, got {type(decoder_prompt)}"
                    )
                prompt_dict["decoder_prompt"] = decoder_prompt.replace(
                    "<|notimestamps|>", "<|0.00|>"
                )
            prompts.append(prompt)
        return prompts, duration

303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
    def _get_verbose_segments(
        self,
        tokens: tuple,
        request: SpeechToTextRequest,
        segment_class: type[SpeechToTextSegment],
        start_time: float = 0,
    ) -> list[SpeechToTextSegment]:
        """
        Convert tokens to verbose segments.

        This method expects the model to produce
        timestamps as tokens (similar to Whisper).
        If the tokens do not include timestamp information,
        the segments may not be generated correctly.

        Note: Fields like avg_logprob, compression_ratio,
        and no_speech_prob are not supported
        in this implementation and will be None. See docs for details.
        """
        BASE_OFFSET = 0.02
        init_token = self.tokenizer.encode("<|0.00|>", add_special_tokens=False)[0]
        if tokens[-1] == self.tokenizer.eos_token_id:
            tokens = tokens[:-1]

        tokens_with_start = (init_token,) + tokens
        segments: list[SpeechToTextSegment] = []
        last_timestamp_start = 0

        if tokens_with_start[-2] < init_token and tokens_with_start[-1] >= init_token:
            tokens_with_start = tokens_with_start + (tokens_with_start[-1],)
        for idx, token in enumerate(tokens_with_start):
            # Timestamp tokens (e.g., <|0.00|>) are assumed to be sorted.
            # If the ordering is violated, this slicing may produce incorrect results.
            if (
                token >= init_token
                and idx != 0
                and tokens_with_start[idx - 1] >= init_token
            ):
                sliced_timestamp_tokens = tokens_with_start[last_timestamp_start:idx]
                start_timestamp = sliced_timestamp_tokens[0] - init_token
                end_timestamp = sliced_timestamp_tokens[-1] - init_token

                casting_segment = cast(
                    SpeechToTextSegment,
                    segment_class(
                        id=len(segments),
                        seek=start_time,
                        start=start_time + BASE_OFFSET * start_timestamp,
                        end=start_time + BASE_OFFSET * end_timestamp,
                        temperature=request.temperature,
                        text=self.tokenizer.decode(sliced_timestamp_tokens[1:-1]),
                        tokens=sliced_timestamp_tokens[1:-1],
                    ),
                )
                segments.append(casting_segment)
                last_timestamp_start = idx
        return segments

361
362
363
364
365
    async def _create_speech_to_text(
        self,
        audio_data: bytes,
        request: SpeechToTextRequest,
        raw_request: Request,
        response_class: type[T | V],
        stream_generator_method: Callable[..., AsyncGenerator[str, None]],
    ) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
        """Base method for speech-to-text operations like transcription and
        translation.

        Args:
            audio_data: Raw bytes of the uploaded audio file.
            request: The incoming speech-to-text request.
            raw_request: The underlying FastAPI request; when truthy, the
                request metadata is attached to its ``state``.
            response_class: Not read by this method; the concrete response
                type is chosen from ``self.task_type`` and
                ``request.response_format``.
            stream_generator_method: Called to build the streaming
                generator when ``request.stream`` is set.

        Returns:
            An error response for invalid or failed requests, the streaming
            generator when streaming is requested, or the final
            transcription / translation response otherwise.

        Raises:
            The engine's ``dead_error`` if the engine client has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        if request.response_format not in ["text", "json", "verbose_json"]:
            return self.create_error_response(
                "Currently only support response_format "
                "`text`, `json` or `verbose_json`"
            )

        if (
            request.response_format == "verbose_json"
            and not self.model_cls.supports_segment_timestamp
        ):
            return self.create_error_response(
                f"Currently do not support verbose_json for {request.model}"
            )

        if request.response_format == "verbose_json" and request.stream:
            return self.create_error_response(
                "verbose_json format doesn't support streaming case"
            )

        request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)

            prompts, duration_s = await self._preprocess_speech_to_text(
                request=request,
                audio_data=audio_data,
            )

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
        try:
            # Unlike most decoder-only models, whisper generation length is not
            # constrained by the size of the input audio, which is mapped to a
            # fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
            # generated by respecting the extra completion tokens arg.
            if request.max_completion_tokens is None:
                default_max_tokens = self.model_config.max_model_len
            else:
                default_max_tokens = min(
                    self.model_config.max_model_len, request.max_completion_tokens
                )
            sampling_params = request.to_sampling_params(
                default_max_tokens, self.default_sampling_params
            )

            self._log_inputs(
                request_id,
                # It will not display special tokens like <|startoftranscript|>
                request.prompt,
                params=sampling_params,
                lora_request=lora_request,
            )

            # One generation stream per audio chunk; the chunk index is
            # appended to the request id to keep the ids distinct.
            list_result_generator = [
                self.engine_client.generate(
                    prompt,
                    sampling_params,
                    f"{request_id}_{i}",
                    lora_request=lora_request,
                )
                for i, prompt in enumerate(prompts)
            ]
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        if request.stream:
            return stream_generator_method(
                request, list_result_generator, request_id, request_metadata, duration_s
            )
        # Non-streaming response.
        total_segments = []
        text_parts = []
        try:
            assert list_result_generator is not None
            segments_types: dict[str, type[SpeechToTextSegment]] = {
                "transcribe": TranscriptionSegment,
                "translate": TranslationSegment,
            }
            segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
            for idx, result_generator in enumerate(list_result_generator):
                async for op in result_generator:
                    if request.response_format == "verbose_json":
                        # Offset each chunk's segments by the chunk's
                        # nominal position in the original audio.
                        segments: list[SpeechToTextSegment] = (
                            self._get_verbose_segments(
                                tokens=tuple(op.outputs[0].token_ids),
                                segment_class=segment_class,
                                request=request,
                                start_time=idx * self.asr_config.max_audio_clip_s,
                            )
                        )

                        total_segments.extend(segments)
                        text_parts.extend([seg.text for seg in segments])
                    else:
                        text_parts.append(op.outputs[0].text)
            text = "".join(text_parts)

            final_response: ResponseType
            if self.task_type == "transcribe":
                # add usage in TranscriptionResponse.
                usage = {
                    "type": "duration",
                    # rounded up as per openAI specs
                    "seconds": int(math.ceil(duration_s)),
                }
                if request.response_format != "verbose_json":
                    final_response = cast(
                        T, TranscriptionResponse(text=text, usage=usage)
                    )
                else:
                    final_response = cast(
                        V,
                        TranscriptionResponseVerbose(
                            text=text,
                            language=request.language,
                            duration=str(duration_s),
                            segments=total_segments,
                        ),
                    )
            else:
                # no usage in response for translation task
                if request.response_format != "verbose_json":
                    final_response = cast(T, TranslationResponse(text=text))
                else:
                    final_response = cast(
                        V,
                        TranslationResponseVerbose(
                            text=text,
                            language=request.language,
                            duration=str(duration_s),
                            segments=total_segments,
                        ),
                    )
            return final_response
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _speech_to_text_stream_generator(
        self,
        request: SpeechToTextRequest,
        list_result_generator: list[AsyncGenerator[RequestOutput, None]],
        request_id: str,
        request_metadata: RequestResponseMetadata,
        audio_duration_s: float,
        chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
        response_stream_choice_class: type[TranscriptionResponseStreamChoice]
        | type[TranslationResponseStreamChoice],
        stream_response_class: type[TranscriptionStreamResponse]
        | type[TranslationStreamResponse],
    ) -> AsyncGenerator[str, None]:
        """Yield server-sent-event strings for a streaming request.

        The result generators (one per audio chunk) are consumed one after
        another. Every engine output becomes one ``data: ...`` event; an
        optional usage event follows, and ``data: [DONE]`` is always sent
        last. Any exception raised while streaming is logged and reported to
        the client as an error event instead of propagating.

        Args:
            request: The originating request.
            list_result_generator: Engine output streams, one per chunk.
            request_id: Id placed on every emitted chunk.
            request_metadata: Receives the final usage info.
            audio_duration_s: Audio duration, passed to the model's
                ``get_num_audio_tokens``.
            chunk_object_type: Value of the ``object`` field on each chunk.
            response_stream_choice_class: Model used for each choice.
            stream_response_class: Model used for each streamed chunk.
        """
        created_time = int(time.time())
        model_name = request.model

        completion_tokens = 0
        num_prompt_tokens = 0

        include_usage = self.enable_force_include_usage or request.stream_include_usage
        # Per-chunk usage is only sent when usage is enabled at all.
        include_continuous_usage = (
            request.stream_continuous_usage_stats
            if include_usage and request.stream_continuous_usage_stats
            else False
        )

        try:
            for result_generator in list_result_generator:
                async for res in result_generator:
                    # On first result.
                    if res.prompt_token_ids is not None:
                        num_prompt_tokens = len(res.prompt_token_ids)
                        if audio_tokens := self.model_cls.get_num_audio_tokens(
                            audio_duration_s, self.asr_config, self.model_config
                        ):
                            num_prompt_tokens += audio_tokens

                    # We need to do it here, because if there are exceptions in
                    # the result_generator, it needs to be sent as the FIRST
                    # response (by the try...catch).

                    # Just one output (n=1) supported.
                    assert len(res.outputs) == 1
                    output = res.outputs[0]

                    delta_message = DeltaMessage(content=output.text)
                    completion_tokens += len(output.token_ids)

                    if output.finish_reason is None:
                        # Still generating, send delta update.
                        choice_data = response_stream_choice_class(delta=delta_message)
                    else:
                        # Model is finished generating.
                        choice_data = response_stream_choice_class(
                            delta=delta_message,
                            finish_reason=output.finish_reason,
                            stop_reason=output.stop_reason,
                        )

                    chunk = stream_response_class(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name,
                    )

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

            # Once the final token is handled, if stream_options.include_usage
            # is sent, send the usage.
            if include_usage:
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )

                final_usage_chunk = stream_response_class(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=num_prompt_tokens + completion_tokens,
            )

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in %s stream generator.", self.task_type)
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

644
645
646
    def _split_audio(
        self, audio_data: np.ndarray, sample_rate: int
    ) -> list[np.ndarray]:
647
648
        chunk_size = sample_rate * self.asr_config.max_audio_clip_s
        overlap_size = sample_rate * self.asr_config.overlap_chunk_second
649
650
651
652
653
654
655
656
657
658
659
        chunks = []
        i = 0
        while i < audio_data.shape[-1]:
            if i + chunk_size >= audio_data.shape[-1]:
                # handle last chunk
                chunks.append(audio_data[..., i:])
                break

            # Find the best split point in the overlap region
            search_start = i + chunk_size - overlap_size
            search_end = min(i + chunk_size, audio_data.shape[-1])
660
            split_point = self._find_split_point(audio_data, search_start, search_end)
661
662
663
664
665
666

            # Extract chunk up to the split point
            chunks.append(audio_data[..., i:split_point])
            i = split_point
        return chunks

667
668
    def _find_split_point(self, wav: np.ndarray, start_idx: int, end_idx: int) -> int:
        """Find the best point to split audio by
669
670
671
672
673
674
675
676
677
678
679
680
681
        looking for silence or low amplitude.
        Args:
            wav: Audio tensor [1, T]
            start_idx: Start index of search region
            end_idx: End index of search region
        Returns:
            Index of best splitting point
        """
        segment = wav[start_idx:end_idx]

        # Calculate RMS energy in small windows
        min_energy = math.inf
        quietest_idx = 0
682
683
684
        min_energy_window = self.asr_config.min_energy_split_window_size
        assert min_energy_window is not None
        for i in range(0, len(segment) - min_energy_window, min_energy_window):
685
686
            window = segment[i : i + min_energy_window]
            energy = (window**2).mean() ** 0.5
687
688
689
690
            if energy < min_energy:
                quietest_idx = i + start_idx
                min_energy = energy
        return quietest_idx