parakeet.py 11.6 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Modules used by the audio encoder component of models/nano_nemotron_vl.py.
"""

from collections.abc import Iterable
8
9
from functools import cache
from typing import Any
10
11
12
13
14

import numpy as np
import torch
import torch.nn as nn
from transformers import ParakeetEncoder as HFParakeetEncoder
15
16
from transformers import PretrainedConfig
from transformers.audio_utils import mel_filter_bank
17

18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.activation import ReLUSquaredActivation
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
22
23
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.parakeet import ExtractorConfig, ParakeetConfig

# Module-level logger; used by ParakeetExtractor to warn about multi-channel input.
logger = init_logger(__name__)


class ParakeetProjection(nn.Module):
    """Project Parakeet encoder hidden states into the LLM hidden space.

    Pipeline: ``RMSNorm`` -> ``Linear(hidden_size, projection_hidden_size)``
    -> ``ReLUSquaredActivation`` -> ``Linear(projection_hidden_size,
    llm_hidden_size)``. Both linear layers use ``config.projection_bias``.
    """

    def __init__(self, config: ParakeetConfig) -> None:
        super().__init__()
        sound_hidden_size = config.hidden_size
        proj_hidden_size = config.projection_hidden_size
        llm_hidden_size = config.llm_hidden_size
        bias = config.projection_bias

        # Submodules are registered in this order; parameter names
        # (norm.*, linear1.*, linear2.*) are what checkpoints load into.
        self.norm = RMSNorm(sound_hidden_size, eps=config.projection_eps)
        self.linear1 = nn.Linear(sound_hidden_size, proj_hidden_size, bias=bias)
        self.activation = ReLUSquaredActivation()
        self.linear2 = nn.Linear(proj_hidden_size, llm_hidden_size, bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply norm -> linear1 -> activation -> linear2.

        Args:
            hidden_states: Encoder outputs whose last dimension is
                ``config.hidden_size``.

        Returns:
            Tensor whose last dimension is ``config.llm_hidden_size``.
        """
        hidden_states = self.norm(hidden_states)
        hidden_states = self.linear1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.linear2(hidden_states)
        return hidden_states


class ProjectedParakeet(nn.Module):
    """Parakeet audio encoder followed by a projection into the LLM hidden space.

    Wraps a Hugging Face ``ParakeetEncoder`` and a ``ParakeetProjection``.
    Both are cast to ``dtype`` at construction time.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        *,
        dtype: torch.dtype,
        llm_hidden_size: int,
        max_model_len: int,
    ) -> None:
        super().__init__()
        self.config = ParakeetConfig.from_hf_config(
            config, llm_hidden_size=llm_hidden_size, max_model_len=max_model_len
        )
        self.encoder = HFParakeetEncoder(self.config).to(dtype)
        self.projection = ParakeetProjection(self.config).to(dtype)

    def forward(
        self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Encode audio features and project them to the LLM hidden size.

        Args:
            input_features: Features forwarded to the encoder as
                ``input_features``.
            attention_mask: Optional mask forwarded to the encoder.

        Returns:
            The projection of the encoder's ``last_hidden_state``.
        """
        hidden_states = self.encoder(
            input_features=input_features, attention_mask=attention_mask
        ).last_hidden_state
        # Cast to the projection's own parameter dtype (the ``dtype`` passed
        # at construction). A hard-coded bfloat16 mismatches the projection
        # weights whenever the model is built with any other dtype.
        hidden_states = hidden_states.to(dtype=self.projection.linear1.weight.dtype)
        return self.projection(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into the encoder and projection.

        Name mapping:
        - ``sound_encoder.encoder.feature_extractor.*`` is skipped; those
          buffers are handled outside the encoder.
        - ``sound_encoder.<rest>`` loads into ``<rest>`` on this module.
        - ``sound_projection.<rest>`` loads into ``projection.<rest>``.
        - Any other name is ignored.

        Args:
            weights: ``(name, tensor)`` pairs, or a dict mapping names to tensors.

        Returns:
            The set of target parameter/buffer names that were loaded.

        Raises:
            ValueError: If a mapped name matches no parameter or buffer.
        """
        params_dict = dict(self.named_parameters())
        buffers_dict = dict(self.named_buffers())
        loaded_params: set[str] = set()

        items = weights.items() if isinstance(weights, dict) else weights
        for name, weight in items:
            if name.startswith("sound_encoder.encoder.feature_extractor."):
                # Feature extractor buffers are handled outside the encoder.
                continue
            if name.startswith("sound_encoder."):
                target_name = name.removeprefix("sound_encoder.")
            elif name.startswith("sound_projection."):
                target_name = "projection." + name.removeprefix("sound_projection.")
            else:
                continue

            # Parameters take precedence; fall back to registered buffers.
            target = params_dict.get(target_name)
            if target is None:
                target = buffers_dict.get(target_name)
            if target is None:
                raise ValueError(f"Unknown weight: {name}")
            weight_loader = getattr(target, "weight_loader", default_weight_loader)
            with torch.no_grad():
                weight_loader(target, weight)
            loaded_params.add(target_name)

        return loaded_params


# Added to the per-feature standard deviation in
# ``ParakeetExtractor._normalize_mel_features`` to avoid dividing by zero.
EPSILON = 1e-5
# Added to the mel power spectrum before ``torch.log`` in
# ``ParakeetExtractor._apply_mel_filters`` so zero-power bins do not give -inf.
LOG_ZERO_GUARD_VALUE = 2**-24


class ParakeetExtractor:
    """Turn raw waveforms into normalized log-mel features for the Parakeet encoder.

    Each waveform is split into clips of ``clip_duration_s`` seconds. A trailing
    remainder, or a whole input, shorter than ``clip_min_duration_s`` seconds is
    zero-padded up to that minimum before features are extracted.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        # The attribute must be named exactly ``config``: ``audio_token_count``
        # passes this extractor as ``self`` to
        # ``HFParakeetEncoder._get_subsampling_output_length``.
        self.config = ExtractorConfig.from_hf_config(config)
        self._clip_target_samples = int(
            round(self.config.clip_duration_s * self.config.sampling_rate)
        )
        self._tail_min_samples = int(
            round(self.config.clip_min_duration_s * self.config.sampling_rate)
        )

    @staticmethod
    @cache
    def _get_window(win_length: int, device: str) -> torch.Tensor:
        """Return a cached symmetric Hann window of ``win_length`` samples."""
        return torch.hann_window(win_length, periodic=False, device=device)

    @staticmethod
    @cache
    def _get_mel_filters(
        feature_size: int, sampling_rate: int, n_fft: int, device: str
    ) -> torch.Tensor:
        """Return a cached float32 Slaney mel filter bank on ``device``.

        The bank is transposed so it left-multiplies a power spectrum with
        ``n_fft // 2 + 1`` frequency bins and yields ``feature_size`` mel bands.
        """
        filter_bank = mel_filter_bank(
            num_frequency_bins=n_fft // 2 + 1,
            num_mel_filters=feature_size,
            min_frequency=0.0,
            max_frequency=sampling_rate / 2,
            sampling_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        return torch.from_numpy(filter_bank.T).to(device=device, dtype=torch.float32)

    def _torch_extract_fbank_features(
        self, waveform: torch.Tensor, device: str
    ) -> torch.Tensor:
        """Compute log-mel features, returned as (batch, frames, feature_size).

        Args:
            waveform: Batched, padded waveforms of shape (batch, samples).
            device: Device for the cached window and mel filters.
        """
        # Canonicalize the device so the @cache'd helpers see a stable key.
        device = str(torch.device(device))
        cfg = self.config
        window = self._get_window(cfg.win_length, device)
        stft = torch.stft(
            waveform,
            cfg.n_fft,
            hop_length=cfg.hop_length,
            win_length=cfg.win_length,
            window=window,
            return_complex=True,
            pad_mode="constant",
        )
        mel_filters = self._get_mel_filters(
            cfg.feature_size, cfg.sampling_rate, cfg.n_fft, device
        )
        return self._apply_mel_filters(stft, mel_filters)

    @torch.compile(dynamic=True)
    def _apply_mel_filters(
        self, stft_output: torch.Tensor, mel_filters: torch.Tensor
    ) -> torch.Tensor:
        """Map a complex STFT (batch, freq, frames) to log-mel (batch, frames, mels)."""
        # Power spectrum |X|^2.
        magnitudes = stft_output.real.square() + stft_output.imag.square()
        mel_spec = mel_filters @ magnitudes
        mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE)
        return mel_spec.permute(0, 2, 1)

    @torch.compile(dynamic=True)
    def _apply_preemphasis(
        self, input_features: torch.Tensor, audio_lengths: torch.Tensor
    ) -> torch.Tensor:
        """Apply ``x[t] - preemphasis * x[t-1]`` for t >= 1, then zero padding.

        Samples at or beyond each row's ``audio_lengths`` entry are set to 0.
        """
        timemask = torch.arange(
            input_features.shape[1], device=input_features.device
        ).unsqueeze(0) < audio_lengths.unsqueeze(1)
        input_features = torch.cat(
            [
                input_features[:, :1],
                input_features[:, 1:]
                - self.config.preemphasis * input_features[:, :-1],
            ],
            dim=1,
        )
        input_features = input_features.masked_fill(~timemask, 0.0)
        return input_features

    @torch.compile(dynamic=True)
    def _normalize_mel_features(
        self, mel_features: torch.Tensor, audio_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize each feature to zero mean / unit std over valid frames.

        Statistics are computed per utterance from its valid frames only. The
        variance divides by ``lengths - 1``.
        NOTE(review): an utterance with exactly one valid frame divides by zero.

        Returns:
            ``(normalized_features, attention_mask)``. Padded frames in the
            features are zeroed; the mask is boolean of shape (batch, frames).
        """
        # Valid frame count per utterance, derived from its sample count.
        features_lengths = torch.floor_divide(
            audio_lengths + self.config.n_fft // 2 * 2 - self.config.n_fft,
            self.config.hop_length,
        )
        attention_mask = (
            torch.arange(mel_features.shape[1], device=mel_features.device)[None, :]
            < features_lengths[:, None]
        )
        mask = attention_mask.unsqueeze(-1)
        lengths = attention_mask.sum(dim=1)
        mel_features_masked = mel_features * mask
        mean = (mel_features_masked.sum(dim=1) / lengths.unsqueeze(-1)).unsqueeze(1)
        variance = ((mel_features_masked - mean) ** 2 * mask).sum(dim=1) / (
            lengths - 1
        ).unsqueeze(-1)
        std = torch.sqrt(variance).unsqueeze(1)
        return (mel_features - mean) / (std + EPSILON) * mask, attention_mask

    def _pad_raw_speech(
        self, raw_speech: list[torch.Tensor], max_len: int, device: str
    ) -> torch.Tensor:
        """Stack waveforms into a (batch, max_len) float32 tensor.

        Positions past each waveform's end hold ``config.padding_value``.
        """
        output = torch.full(
            (len(raw_speech), max_len),
            self.config.padding_value,
            device=device,
            dtype=torch.float32,
        )
        dsts = [output[i, : raw_speech[i].shape[0]] for i in range(len(raw_speech))]
        srcs = [s.squeeze(-1) for s in raw_speech]
        # single kernel horizontal fusion
        torch._foreach_copy_(dsts, srcs)
        return output

    def _clip_sizes(self, audio_len: int) -> list[int]:
        """Return the sample length of each clip for ``audio_len`` samples.

        Full clips are ``_clip_target_samples`` long. A non-empty remainder
        becomes a final clip of at least ``_tail_min_samples``. The input itself
        is treated as at least ``_tail_min_samples`` long.
        """
        audio_len = max(audio_len, self._tail_min_samples)
        num_full_clips, remainder = divmod(audio_len, self._clip_target_samples)
        clip_sizes = [self._clip_target_samples] * num_full_clips
        if remainder > 0:
            clip_sizes.append(max(remainder, self._tail_min_samples))
        return clip_sizes

    def audio_token_count(self, audio_len: int) -> int:
        """Return the encoder output token count for ``audio_len`` samples (>= 1)."""
        total_tokens = 0
        for clip_size in self._clip_sizes(audio_len):
            num_frames = clip_size // self.config.hop_length
            # ``self`` stands in for the encoder here; the helper reads
            # ``self.config`` (see the note in ``__init__``).
            n_tokens = HFParakeetEncoder._get_subsampling_output_length(
                self, torch.tensor([num_frames], dtype=torch.float)
            )
            total_tokens += int(n_tokens.item())
        return max(1, total_tokens)

    def split_audio_into_clips(self, audio: torch.Tensor) -> list[torch.Tensor]:
        """Split a 1-D waveform into clips sized by ``_clip_sizes``.

        The waveform is zero-padded at the end when the clip sizes sum to more
        than its length. Clips are slices of that (possibly padded) tensor.

        Raises:
            ValueError: If ``audio`` is not one-dimensional.
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if audio.ndim != 1:
            raise ValueError(
                f"Expected a 1-D waveform, got a tensor of shape {tuple(audio.shape)}"
            )
        audio_len = int(audio.shape[0])
        clip_sizes = self._clip_sizes(audio_len)
        target_len = sum(clip_sizes)
        if audio_len < target_len:
            audio = torch.nn.functional.pad(audio, (0, target_len - audio_len))

        clips: list[torch.Tensor] = []
        offset = 0
        for clip_size in clip_sizes:
            clips.append(audio[offset : offset + clip_size])
            offset += clip_size
        return clips

    def __call__(
        self,
        raw_speech: list[np.ndarray],
        *,
        device: str = "cpu",
    ) -> dict[str, Any]:
        """Extract batched features for a list of waveforms.

        Multi-channel inputs are averaged over their last axis to mono. Every
        waveform is split into clips, and all clips are batched together.

        Returns:
            A dict with ``input_audio_features`` (clips, frames, feature_size),
            ``feature_attention_mask`` (clips, frames), and ``audio_num_clips``
            (the number of clips each input waveform contributed).

        Raises:
            ValueError: If there are no audio samples to process.
        """
        waveforms = [
            torch.as_tensor(speech, device=device, dtype=torch.float32)
            for speech in raw_speech
        ]

        for i, speech in enumerate(waveforms):
            if len(speech.shape) > 1:
                logger.warning(
                    "Only mono-channel audio is supported for input to %s. "
                    "We will take the mean of the channels to convert to mono.",
                    self.__class__.__name__,
                )
                waveforms[i] = speech.mean(-1)

        audio_clips: list[torch.Tensor] = []
        audio_num_clips: list[int] = []
        for audio in waveforms:
            clips = self.split_audio_into_clips(audio)
            audio_clips.extend(clips)
            audio_num_clips.append(len(clips))

        # Guard explicitly; ``max`` below would otherwise fail opaquely.
        if not audio_clips:
            raise ValueError("No audio samples were provided for feature extraction")

        audio_lengths = torch.tensor(
            [len(clip) for clip in audio_clips], dtype=torch.long, device=device
        )

        max_length = max(len(clip) for clip in audio_clips)
        input_features = self._pad_raw_speech(audio_clips, max_length, device)
        input_features = self._apply_preemphasis(input_features, audio_lengths)
        input_features = self._torch_extract_fbank_features(input_features, device)
        input_features, attention_mask = self._normalize_mel_features(
            input_features, audio_lengths
        )

        return {
            "input_audio_features": input_features,
            "feature_attention_mask": attention_mask,
            "audio_num_clips": audio_num_clips,
        }

    @staticmethod
    def audio_length(raw_config: PretrainedConfig, audio_tokens: int) -> int:
        """Convert an encoder token count to a sample count.

        Computes ``audio_tokens * subsampling_factor * hop_length``.
        """
        config = ExtractorConfig.from_hf_config(raw_config)
        return int(audio_tokens * config.subsampling_factor * config.hop_length)