Unverified commit 6adacfcb, authored by Netanel Haber and committed by GitHub
Browse files

ParakeetExtractor performance and UX enhancements (#39423)


Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
parent 14cb86c1
...@@ -5,19 +5,24 @@ Modules below used for the audio encoder component in: models/nano_nemotron_vl.p ...@@ -5,19 +5,24 @@ Modules below used for the audio encoder component in: models/nano_nemotron_vl.p
""" """
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import asdict from functools import cache
from typing import Any
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import ParakeetEncoder as HFParakeetEncoder from transformers import ParakeetEncoder as HFParakeetEncoder
from transformers import ParakeetFeatureExtractor, PretrainedConfig from transformers import PretrainedConfig
from transformers.audio_utils import mel_filter_bank
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.parakeet import ExtractorConfig, ParakeetConfig from vllm.transformers_utils.configs.parakeet import ExtractorConfig, ParakeetConfig
logger = init_logger(__name__)
class ParakeetProjection(nn.Module): class ParakeetProjection(nn.Module):
def __init__(self, config: ParakeetConfig) -> None: def __init__(self, config: ParakeetConfig) -> None:
...@@ -103,16 +108,124 @@ class ProjectedParakeet(nn.Module): ...@@ -103,16 +108,124 @@ class ProjectedParakeet(nn.Module):
return loaded_params return loaded_params
class ParakeetExtractor(ParakeetFeatureExtractor): EPSILON = 1e-5
LOG_ZERO_GUARD_VALUE = 2**-24
class ParakeetExtractor:
def __init__(self, config: PretrainedConfig) -> None:
    """Build the extractor settings and pre-compute clip sizes in samples.

    Args:
        config: Hugging Face config that `ExtractorConfig.from_hf_config`
            converts into the extractor's own settings object.
    """
    # NOTE: the attribute must be called `config`; `audio_token_count` passes
    # this object as `self` to `HFParakeetEncoder._get_subsampling_output_length`,
    # which presumably looks its settings up under that exact name.
    self.config = ExtractorConfig.from_hf_config(config)

    sampling_rate = self.config.sampling_rate
    # Durations are configured in seconds; clip splitting works in samples.
    self._clip_target_samples = int(round(self.config.clip_duration_s * sampling_rate))
    self._tail_min_samples = int(
        round(self.config.clip_min_duration_s * sampling_rate)
    )
@staticmethod
@cache
def _get_window(win_length: int, device: str) -> torch.Tensor:
return torch.hann_window(win_length, periodic=False, device=device)
@staticmethod
@cache
def _get_mel_filters(
    feature_size: int, sampling_rate: int, n_fft: int, device: str
) -> torch.Tensor:
    """Build the Slaney mel filter bank as a float32 tensor on ``device``.

    The bank covers 0 Hz up to the Nyquist frequency and is memoized per
    argument tuple, so the returned tensor is shared between callers.

    Args:
        feature_size: Number of mel filters.
        sampling_rate: Audio sampling rate in Hz.
        n_fft: FFT size; the bank has ``n_fft // 2 + 1`` frequency bins.
        device: Device string the tensor is moved to.

    Returns:
        The transposed filter bank. It is used as the left operand of
        ``mel_filters @ power_spectrum``, i.e. presumably laid out as
        (mel filters, frequency bins).
    """
    num_bins = n_fft // 2 + 1
    nyquist = sampling_rate / 2
    bank = mel_filter_bank(
        num_frequency_bins=num_bins,
        num_mel_filters=feature_size,
        min_frequency=0.0,
        max_frequency=nyquist,
        sampling_rate=sampling_rate,
        norm="slaney",
        mel_scale="slaney",
    )
    transposed = torch.from_numpy(bank.T)
    return transposed.to(device=device, dtype=torch.float32)
def _torch_extract_fbank_features(
    self, waveform: torch.Tensor, device: str
) -> torch.Tensor:
    """Turn padded waveforms into log-mel features via an STFT.

    Args:
        waveform: Batch of padded, pre-emphasized waveforms.
        device: Device on which the window and mel filters are created.

    Returns:
        Whatever `_apply_mel_filters` produces for the complex spectrogram.
    """
    cfg = self.config
    # Canonicalize the device spelling so it is a stable key for the
    # `@cache`-decorated helpers below.
    device_key = str(torch.device(device))

    hann = self._get_window(cfg.win_length, device_key)
    mel_filters = self._get_mel_filters(
        cfg.feature_size, cfg.sampling_rate, cfg.n_fft, device_key
    )

    # pad_mode="constant" pads the signal edges with constants rather than
    # torch.stft's default reflection.
    spectrogram = torch.stft(
        waveform,
        cfg.n_fft,
        hop_length=cfg.hop_length,
        win_length=cfg.win_length,
        window=hann,
        return_complex=True,
        pad_mode="constant",
    )
    return self._apply_mel_filters(spectrogram, mel_filters)
@torch.compile(dynamic=True)
def _apply_mel_filters(
    self, stft_output: torch.Tensor, mel_filters: torch.Tensor
) -> torch.Tensor:
    """Project a complex STFT onto the mel filters and take the log.

    Args:
        stft_output: Complex spectrogram, rank 3.
        mel_filters: Filter matrix applied from the left.

    Returns:
        Log-mel features with the last two axes swapped relative to the
        matrix product.
    """
    # Squared magnitude (power) of each complex bin.
    power = stft_output.real.square() + stft_output.imag.square()
    # The small additive guard keeps log() finite where the energy is zero.
    log_mel = torch.log(mel_filters @ power + LOG_ZERO_GUARD_VALUE)
    return log_mel.permute(0, 2, 1)
@torch.compile(dynamic=True)
def _apply_preemphasis(
self, input_features: torch.Tensor, audio_lengths: torch.Tensor
) -> torch.Tensor:
timemask = torch.arange(
input_features.shape[1], device=input_features.device
).unsqueeze(0) < audio_lengths.unsqueeze(1)
input_features = torch.cat(
[
input_features[:, :1],
input_features[:, 1:]
- self.config.preemphasis * input_features[:, :-1],
],
dim=1,
)
input_features = input_features.masked_fill(~timemask, 0.0)
return input_features
@torch.compile(dynamic=True)
def _normalize_mel_features(
    self, mel_features: torch.Tensor, audio_lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Normalize each clip's features to zero mean / unit std over valid frames.

    Args:
        mel_features: Features indexed as (clip, frame, feature).
        audio_lengths: Number of valid samples per clip.

    Returns:
        A pair ``(normalized, attention_mask)``. ``normalized`` is zero at
        padded frames; ``attention_mask`` is a boolean (clip, frame) mask of
        the valid frames.
    """
    cfg = self.config
    # Valid frame count per clip, derived from the sample count.
    padded_lengths = audio_lengths + cfg.n_fft // 2 * 2 - cfg.n_fft
    features_lengths = torch.floor_divide(padded_lengths, cfg.hop_length)

    frame_index = torch.arange(mel_features.shape[1], device=mel_features.device)
    attention_mask = frame_index[None, :] < features_lengths[:, None]

    valid = attention_mask.unsqueeze(-1)
    num_valid = attention_mask.sum(dim=1).unsqueeze(-1)

    masked = mel_features * valid
    mean = (masked.sum(dim=1) / num_valid).unsqueeze(1)
    # Sample variance (n - 1 denominator) over valid frames only; a clip with
    # a single valid frame therefore divides by zero.
    squared_dev = (masked - mean) ** 2 * valid
    variance = squared_dev.sum(dim=1) / (num_valid - 1)
    std = torch.sqrt(variance).unsqueeze(1)

    normalized = (mel_features - mean) / (std + EPSILON) * valid
    return normalized, attention_mask
def _pad_raw_speech(
self, raw_speech: list[torch.Tensor], max_len: int, device: str
) -> torch.Tensor:
output = torch.full(
(len(raw_speech), max_len),
self.config.padding_value,
device=device,
dtype=torch.float32,
)
dsts = [output[i, : raw_speech[i].shape[0]] for i in range(len(raw_speech))]
srcs = [s.squeeze(-1) for s in raw_speech]
# single kernel horizontal fusion
torch._foreach_copy_(dsts, srcs)
return output
def _clip_sizes(self, audio_len: int) -> list[int]: def _clip_sizes(self, audio_len: int) -> list[int]:
audio_len = max(audio_len, self._tail_min_samples) audio_len = max(audio_len, self._tail_min_samples)
...@@ -125,39 +238,73 @@ class ParakeetExtractor(ParakeetFeatureExtractor): ...@@ -125,39 +238,73 @@ class ParakeetExtractor(ParakeetFeatureExtractor):
def audio_token_count(self, audio_len: int) -> int:
    """Return how many encoder tokens an audio of ``audio_len`` samples yields.

    The audio is sized clip by clip via `_clip_sizes`; each clip's frame count
    is run through the Hugging Face encoder's subsampling-length helper and the
    per-clip results are summed. The result is never below 1.
    """

    def tokens_for(clip_size: int) -> int:
        frames = torch.tensor(
            [clip_size // self.config.hop_length], dtype=torch.float
        )
        # `self` stands in for the encoder here; the helper presumably only
        # needs the `config` attribute from it.
        subsampled = HFParakeetEncoder._get_subsampling_output_length(self, frames)
        return int(subsampled.item())

    token_total = sum(tokens_for(size) for size in self._clip_sizes(audio_len))
    return max(1, token_total)
def split_audio_into_clips(self, audio: torch.Tensor) -> list[torch.Tensor]:
    """Cut a 1-D waveform into consecutive clips sized by `_clip_sizes`.

    If the waveform is shorter than the combined clip sizes it is first
    zero-padded on the right, so every returned clip has its full planned size.
    """
    assert audio.ndim == 1
    audio_len = int(audio.shape[0])
    clip_sizes = self._clip_sizes(audio_len)

    shortfall = sum(clip_sizes) - audio_len
    if shortfall > 0:
        audio = torch.nn.functional.pad(audio, (0, shortfall))

    clips = list[torch.Tensor]()
    start = 0
    for size in clip_sizes:
        end = start + size
        clips.append(audio[start:end])
        start = end
    return clips
def __call__(
    self,
    raw_speech: list[np.ndarray],
    *,
    device: str = "cpu",
) -> dict[str, Any]:
    """Convert raw audio arrays into normalized log-mel features.

    Each input is converted to a float32 tensor on ``device``, reduced to mono
    if it has more than one axis, split into clips, and then all clips are
    padded, pre-emphasized, turned into mel features and normalized together.

    Args:
        raw_speech: One array per audio.
        device: Device on which all tensors are created and processed.

    Returns:
        A dict with ``input_audio_features``, ``feature_attention_mask`` and
        ``audio_num_clips`` (how many clips each input audio was split into).
    """
    waveforms = [
        torch.as_tensor(speech, device=device, dtype=torch.float32)
        for speech in raw_speech
    ]
    for idx, waveform in enumerate(waveforms):
        if len(waveform.shape) > 1:
            logger.warning(
                "Only mono-channel audio is supported for input to %s. "
                "We will take the mean of the channels to convert to mono.",
                self.__class__.__name__,
            )
            # Channels are averaged over the last axis.
            waveforms[idx] = waveform.mean(-1)

    audio_clips = list[torch.Tensor]()
    audio_num_clips = list[int]()
    for waveform in waveforms:
        clips = self.split_audio_into_clips(waveform)
        audio_clips.extend(clips)
        audio_num_clips.append(len(clips))

    # From here on the batch dimension is clips, not the original audios.
    audio_lengths = torch.tensor(
        [len(clip) for clip in audio_clips], dtype=torch.long, device=device
    )
    longest = max(len(clip) for clip in audio_clips)

    padded = self._pad_raw_speech(audio_clips, longest, device)
    emphasized = self._apply_preemphasis(padded, audio_lengths)
    mel = self._torch_extract_fbank_features(emphasized, device)
    features, attention_mask = self._normalize_mel_features(mel, audio_lengths)

    return {
        "input_audio_features": features,
        "feature_attention_mask": attention_mask,
        "audio_num_clips": audio_num_clips,
    }
@staticmethod @staticmethod
def audio_length(raw_config: PretrainedConfig, audio_tokens: int) -> int: def audio_length(raw_config: PretrainedConfig, audio_tokens: int) -> int:
......
...@@ -49,15 +49,24 @@ class ExtractorConfig: ...@@ -49,15 +49,24 @@ class ExtractorConfig:
clip_duration_s: int = 30 clip_duration_s: int = 30
clip_min_duration_s: float = 0.1 clip_min_duration_s: float = 0.1
@staticmethod win_length: int = 400
def from_hf_config(config: PretrainedConfig) -> "ExtractorConfig": preemphasis: float = 0.97
n_fft: int = 512
padding_value: float = 0.0
@classmethod
def from_hf_config(cls, config: PretrainedConfig) -> "ExtractorConfig":
    """Create an extractor config from a Hugging Face ``PretrainedConfig``.

    The mel-bin count, sampling rate and subsampling settings are read
    unconditionally. ``hop_length``, ``win_length``, ``preemphasis``, ``n_fft``
    and ``padding_value`` are copied only when ``config`` defines them, so the
    dataclass defaults apply otherwise.
    """
    assert isinstance(config, PretrainedConfig)

    overridable = ("hop_length", "win_length", "preemphasis", "n_fft", "padding_value")
    overrides = {}
    for field_name in overridable:
        if hasattr(config, field_name):
            overrides[field_name] = getattr(config, field_name)

    required = {
        "feature_size": config.num_mel_bins,
        "sampling_rate": config.sampling_rate,
        "subsampling_factor": config.subsampling_factor,
        "subsampling_conv_kernel_size": config.subsampling_conv_kernel_size,
        "subsampling_conv_stride": config.subsampling_conv_stride,
    }
    return cls(**required, **overrides)
...@@ -992,17 +992,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -992,17 +992,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
parts[idx] = audio_repl.full parts[idx] = audio_repl.full
audio_index += 1 audio_index += 1
text = ["".join(parts)] text = ["".join(parts)]
audio_inputs = extractor( audio_inputs = extractor(audios)
audios,
sampling_rate=extractor.sampling_rate,
return_tensors="pt",
)
audio_inputs = {
"input_audio_features": audio_inputs.input_features,
"feature_attention_mask": audio_inputs.attention_mask,
"audio_num_clips": audio_inputs.audio_num_clips,
}
return text, audio_inputs return text, audio_inputs
def __call__( def __call__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment