Unverified commit 6adacfcb, authored by Netanel Haber and committed by GitHub
Browse files

ParakeetExtractor performance and UX enhancements (#39423)


Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
parent 14cb86c1
......@@ -5,19 +5,24 @@ Modules below used for the audio encoder component in: models/nano_nemotron_vl.p
"""
from collections.abc import Iterable
from dataclasses import asdict
from functools import cache
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from transformers import ParakeetEncoder as HFParakeetEncoder
from transformers import ParakeetFeatureExtractor, PretrainedConfig
from transformers import PretrainedConfig
from transformers.audio_utils import mel_filter_bank
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.parakeet import ExtractorConfig, ParakeetConfig
logger = init_logger(__name__)
class ParakeetProjection(nn.Module):
def __init__(self, config: ParakeetConfig) -> None:
......@@ -103,16 +108,124 @@ class ProjectedParakeet(nn.Module):
return loaded_params
class ParakeetExtractor(ParakeetFeatureExtractor):
EPSILON = 1e-5
LOG_ZERO_GUARD_VALUE = 2**-24
class ParakeetExtractor:
def __init__(self, config: PretrainedConfig) -> None:
self.config = ExtractorConfig.from_hf_config(config)
super().__init__(**asdict(self.config))
"""`config` is named *exactly* for `._get_subsampling_output_length` below"""
self._clip_target_samples = int(
round(self.config.clip_duration_s * self.sampling_rate)
round(self.config.clip_duration_s * self.config.sampling_rate)
)
self._tail_min_samples = int(
round(self.config.clip_min_duration_s * self.sampling_rate)
round(self.config.clip_min_duration_s * self.config.sampling_rate)
)
@staticmethod
@cache
def _get_window(win_length: int, device: str) -> torch.Tensor:
return torch.hann_window(win_length, periodic=False, device=device)
@staticmethod
@cache
def _get_mel_filters(
    feature_size: int, sampling_rate: int, n_fft: int, device: str
) -> torch.Tensor:
    """Build the Slaney mel filter bank as a float32 tensor on ``device``.

    The bank spans 0 Hz up to half the sampling rate, uses
    ``n_fft // 2 + 1`` frequency bins and ``feature_size`` mel filters.
    The array returned by ``mel_filter_bank`` is transposed before it is
    converted to a tensor. Memoized per argument tuple.
    """
    num_bins = n_fft // 2 + 1
    nyquist = sampling_rate / 2
    bank = mel_filter_bank(
        num_frequency_bins=num_bins,
        num_mel_filters=feature_size,
        min_frequency=0.0,
        max_frequency=nyquist,
        sampling_rate=sampling_rate,
        norm="slaney",
        mel_scale="slaney",
    )
    transposed = torch.from_numpy(bank.T)
    return transposed.to(device=device, dtype=torch.float32)
def _torch_extract_fbank_features(self, waveform: torch.Tensor, device: str):
    """Compute log-mel features for ``waveform`` on ``device``.

    Runs an STFT with the configured FFT size, hop and window, then hands
    the complex result and the cached mel filter bank to
    ``_apply_mel_filters``.
    """
    cfg = self.config
    # Normalize to a plain string before passing it to the cached helpers.
    device = str(torch.device(device))

    # spectrogram
    hann = self._get_window(cfg.win_length, device)
    spectrum = torch.stft(
        waveform,
        n_fft=cfg.n_fft,
        hop_length=cfg.hop_length,
        win_length=cfg.win_length,
        window=hann,
        pad_mode="constant",
        return_complex=True,
    )

    filters = self._get_mel_filters(
        cfg.feature_size, cfg.sampling_rate, cfg.n_fft, device
    )
    return self._apply_mel_filters(spectrum, filters)
@torch.compile(dynamic=True)
def _apply_mel_filters(
    self, stft_output: torch.Tensor, mel_filters: torch.Tensor
) -> torch.Tensor:
    """Project a complex STFT onto the mel filters and take the log.

    The power spectrum (real² + imag²) is multiplied by ``mel_filters``,
    offset by ``LOG_ZERO_GUARD_VALUE`` so the logarithm never sees zero,
    and returned with its last two axes swapped.
    """
    real_part = stft_output.real
    imag_part = stft_output.imag
    power = real_part.square() + imag_part.square()
    projected = torch.matmul(mel_filters, power)
    log_mel = torch.log(projected + LOG_ZERO_GUARD_VALUE)
    return log_mel.permute(0, 2, 1)
@torch.compile(dynamic=True)
def _apply_preemphasis(
self, input_features: torch.Tensor, audio_lengths: torch.Tensor
) -> torch.Tensor:
timemask = torch.arange(
input_features.shape[1], device=input_features.device
).unsqueeze(0) < audio_lengths.unsqueeze(1)
input_features = torch.cat(
[
input_features[:, :1],
input_features[:, 1:]
- self.config.preemphasis * input_features[:, :-1],
],
dim=1,
)
input_features = input_features.masked_fill(~timemask, 0.0)
return input_features
@torch.compile(dynamic=True)
def _normalize_mel_features(
    self, mel_features: torch.Tensor, audio_lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Normalize each clip's mel features over its valid frames.

    The valid frame count per clip is derived from ``audio_lengths``, the
    FFT size and the hop length. Mean and standard deviation are taken
    over valid frames only (variance is divided by ``count - 1``), and
    frames outside the mask are zeroed in the output.

    Returns:
        The normalized features and the boolean frame mask.
    """
    n_fft = self.config.n_fft
    padded_lengths = audio_lengths + n_fft // 2 * 2 - n_fft
    features_lengths = torch.floor_divide(padded_lengths, self.config.hop_length)

    frame_index = torch.arange(mel_features.shape[1], device=mel_features.device)
    attention_mask = frame_index[None, :] < features_lengths[:, None]

    keep = attention_mask.unsqueeze(-1)
    valid_counts = attention_mask.sum(dim=1)

    masked = mel_features * keep
    mean = (masked.sum(dim=1) / valid_counts.unsqueeze(-1)).unsqueeze(1)

    # NOTE(review): a clip with exactly one valid frame divides by zero
    # here — confirm callers never produce such clips.
    squared_dev = (masked - mean) ** 2 * keep
    variance = squared_dev.sum(dim=1) / (valid_counts - 1).unsqueeze(-1)
    std = torch.sqrt(variance).unsqueeze(1)

    normalized = (mel_features - mean) / (std + EPSILON) * keep
    return normalized, attention_mask
def _pad_raw_speech(
self, raw_speech: list[torch.Tensor], max_len: int, device: str
) -> torch.Tensor:
output = torch.full(
(len(raw_speech), max_len),
self.config.padding_value,
device=device,
dtype=torch.float32,
)
dsts = [output[i, : raw_speech[i].shape[0]] for i in range(len(raw_speech))]
srcs = [s.squeeze(-1) for s in raw_speech]
# single kernel horizontal fusion
torch._foreach_copy_(dsts, srcs)
return output
def _clip_sizes(self, audio_len: int) -> list[int]:
audio_len = max(audio_len, self._tail_min_samples)
......@@ -125,39 +238,73 @@ class ParakeetExtractor(ParakeetFeatureExtractor):
def audio_token_count(self, audio_len: int) -> int:
total_tokens = 0
for clip_size in self._clip_sizes(audio_len):
num_frames = clip_size // self.hop_length
num_frames = clip_size // self.config.hop_length
n_tokens = HFParakeetEncoder._get_subsampling_output_length(
self, torch.tensor([num_frames], dtype=torch.float)
)
total_tokens += int(n_tokens.item())
return max(1, total_tokens)
def split_audio_into_clips(self, audio: np.ndarray) -> list[np.ndarray]:
def split_audio_into_clips(self, audio: torch.Tensor) -> list[torch.Tensor]:
assert audio.ndim == 1
audio_len = int(audio.shape[0])
clip_sizes = self._clip_sizes(audio_len)
target_len = sum(clip_sizes)
if audio_len < target_len:
audio = np.pad(audio, (0, target_len - audio_len))
audio = torch.nn.functional.pad(audio, (0, target_len - audio_len))
clips = list[np.ndarray]()
clips = list[torch.Tensor]()
offset = 0
for clip_size in clip_sizes:
clips.append(audio[offset : offset + clip_size])
offset += clip_size
return clips
def __call__(self, raw_speech: list[np.ndarray], *args, **kwargs):
audio_clips = list[np.ndarray]()
def __call__(
self,
raw_speech: list[np.ndarray],
*,
device: str = "cpu",
) -> dict[str, Any]:
raw_speech = [
torch.as_tensor(speech, device=device, dtype=torch.float32)
for speech in raw_speech
]
for i, speech in enumerate(raw_speech):
if len(speech.shape) > 1:
logger.warning(
"Only mono-channel audio is supported for input to %s. "
"We will take the mean of the channels to convert to mono.",
self.__class__.__name__,
)
raw_speech[i] = speech.mean(-1)
audio_clips = list[torch.Tensor]()
audio_num_clips = list[int]()
for audio in raw_speech:
clips = self.split_audio_into_clips(audio)
audio_clips.extend(clips)
audio_num_clips.append(len(clips))
raw_speech = audio_clips
outputs = super().__call__(audio_clips, *args, **kwargs)
outputs["audio_num_clips"] = audio_num_clips
return outputs
audio_lengths = torch.tensor(
[len(speech) for speech in raw_speech], dtype=torch.long, device=device
)
max_length = max(len(speech) for speech in raw_speech)
input_features = self._pad_raw_speech(raw_speech, max_length, device)
input_features = self._apply_preemphasis(input_features, audio_lengths)
input_features = self._torch_extract_fbank_features(input_features, device)
input_features, attention_mask = self._normalize_mel_features(
input_features, audio_lengths
)
return {
"input_audio_features": input_features,
"feature_attention_mask": attention_mask,
"audio_num_clips": audio_num_clips,
}
@staticmethod
def audio_length(raw_config: PretrainedConfig, audio_tokens: int) -> int:
......
......@@ -49,15 +49,24 @@ class ExtractorConfig:
clip_duration_s: int = 30
clip_min_duration_s: float = 0.1
@staticmethod
def from_hf_config(config: PretrainedConfig) -> "ExtractorConfig":
win_length: int = 400
preemphasis: float = 0.97
n_fft: int = 512
padding_value: float = 0.0
@classmethod
def from_hf_config(cls, config: PretrainedConfig) -> "ExtractorConfig":
assert isinstance(config, PretrainedConfig)
hop_length = int(getattr(config, "hop_length", ExtractorConfig.hop_length))
return ExtractorConfig(
defaults = ("hop_length", "win_length", "preemphasis", "n_fft", "padding_value")
optional_kwargs = {
name: getattr(config, name) for name in defaults if hasattr(config, name)
}
return cls(
feature_size=config.num_mel_bins,
sampling_rate=config.sampling_rate,
hop_length=hop_length,
subsampling_factor=config.subsampling_factor,
subsampling_conv_kernel_size=config.subsampling_conv_kernel_size,
subsampling_conv_stride=config.subsampling_conv_stride,
**optional_kwargs,
)
......@@ -992,17 +992,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
parts[idx] = audio_repl.full
audio_index += 1
text = ["".join(parts)]
audio_inputs = extractor(
audios,
sampling_rate=extractor.sampling_rate,
return_tensors="pt",
)
audio_inputs = {
"input_audio_features": audio_inputs.input_features,
"feature_attention_mask": audio_inputs.attention_mask,
"audio_num_clips": audio_inputs.audio_num_clips,
}
audio_inputs = extractor(audios)
return text, audio_inputs
def __call__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment