# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch
import torch.nn as nn
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    AutoFeatureExtractor,
    BatchFeature,
)
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.processing_utils import ProcessorMixin
from transformers.utils import TensorType

from vllm.logger import init_logger

logger = init_logger(__name__)


def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:  # noqa
    """Apply CMVN statistics to a 2-D feature matrix.

    Args:
        inputs: feature matrix of shape ``(frames, dim)``.
        cmvn: 2-row tensor as produced by :func:`load_cmvn`. Row 0 is the
            additive shift and row 1 the multiplicative scale. Only the
            first ``dim`` columns are used.

    Returns:
        A new float32 tensor equal to ``(inputs + shift) * scale``.
        ``inputs`` itself is left untouched.
    """
    _, dim = inputs.shape
    device = inputs.device
    shift = cmvn[0:1, :dim].to(device)
    scale = cmvn[1:2, :dim].to(device)
    # Out-of-place on purpose: callers may pass a view of a caller-owned
    # tensor (see WavFrontend.forward_lfr_cmvn), which must not be modified.
    return ((inputs + shift) * scale).type(torch.float32)


def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int) -> torch.Tensor:
    """Apply low-frame-rate (LFR) stacking to a feature matrix.

    Each output row concatenates ``lfr_m`` consecutive input frames, and
    consecutive output rows start ``lfr_n`` input frames apart. The sequence
    is first left-padded with ``(lfr_m - 1) // 2`` copies of the first
    frame. The trailing windows are completed with copies of the last frame.

    Args:
        inputs: feature matrix of shape ``(T, feat_dim)``; ``T`` must be >= 1.
        lfr_m: number of frames stacked into each output row.
        lfr_n: hop, in input frames, between output rows.

    Returns:
        float32 tensor of shape ``(ceil(T / lfr_n), lfr_m * feat_dim)``.
    """
    num_frames = inputs.shape[0]
    num_out = int(np.ceil(num_frames / lfr_n))
    left_pad = (lfr_m - 1) // 2
    left_padding = inputs[0].repeat(left_pad, 1)
    padded = torch.vstack((left_padding, inputs))
    num_padded = num_frames + left_pad
    feat_dim = padded.shape[-1]

    # The last output row starts at (num_out - 1) * lfr_n and spans lfr_m
    # frames. Replicate the final frame until that window fits.
    right_pad = (num_out - 1) * lfr_n + lfr_m - num_padded
    if right_pad > 0:
        padded = torch.vstack([padded] + [padded[-1:]] * right_pad)

    # vstack yields a contiguous buffer, so the overlapping windows can be
    # expressed as a strided view. Clone so the result owns its memory.
    windows = padded.as_strided((num_out, lfr_m * feat_dim), (lfr_n * feat_dim, 1))
    return windows.clone().type(torch.float32)


def load_cmvn(cmvn_file):
    """Load CMVN statistics from a Kaldi nnet-style text file.

    The parser looks for a ``<AddShift>`` line and a ``<Rescale>`` line. Each
    must be followed by a line of the form
    ``<LearnRateCoef> <coef> [ v_1 ... v_D ]``. The ``v_i`` values of the
    ``<AddShift>`` entry become the shift row; those of the ``<Rescale>``
    entry become the scale row.

    Args:
        cmvn_file: path to the statistics file (read as UTF-8 text).

    Returns:
        float32 tensor of shape ``(2, D)``. Row 0 is the shift and row 1 the
        scale, as consumed by :func:`apply_cmvn`.
    """
    with open(cmvn_file, encoding="utf-8") as f:
        rows = [line.split() for line in f]

    shift_values: list[str] = []
    scale_values: list[str] = []
    for idx, tokens in enumerate(rows):
        # Skip blank lines and a header on the final line (no value line).
        if not tokens or idx + 1 >= len(rows):
            continue
        values_line = rows[idx + 1]
        if not values_line or values_line[0] != "<LearnRateCoef>":
            continue
        # Drop "<LearnRateCoef> <coef> [" and the trailing "]".
        if tokens[0] == "<AddShift>":
            shift_values = values_line[3:-1]
        elif tokens[0] == "<Rescale>":
            scale_values = values_line[3:-1]

    means = np.array(shift_values).astype(np.float32)
    scales = np.array(scale_values).astype(np.float32)
    return torch.as_tensor(np.array([means, scales]), dtype=torch.float32)


class WavFrontend(nn.Module):
    """Conventional Kaldi-style fbank frontend for ASR.

    Per utterance: log mel filter-bank (``kaldi.fbank``), then optional LFR
    stacking (:func:`apply_lfr`), then optional CMVN (:func:`apply_cmvn`).
    """

    def __init__(
        self,
        cmvn_file: str | None = "null",
        fs: int = 16000,
        window: str = "hamming",
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        # Stored for config compatibility; not read anywhere in this class.
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        # (sic) "upsacle": the misspelled keyword is part of the public
        # config interface and is kept as-is.
        self.upsacle_samples = upsacle_samples
        # Both None and the YAML-style string "null" mean "no CMVN". The
        # string default would otherwise make load_cmvn() open a file
        # literally named "null".
        if cmvn_file is None or cmvn_file == "null":
            self.cmvn = None
        else:
            self.cmvn = load_cmvn(cmvn_file)

    def output_size(self) -> int:
        """Return the per-frame feature dimension after LFR stacking."""
        return self.n_mels * self.lfr_m

    def _apply_lfr_cmvn(self, mat: torch.Tensor) -> torch.Tensor:
        """Apply the configured LFR stacking and CMVN to one feature matrix."""
        if self.lfr_m != 1 or self.lfr_n != 1:
            mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
        if self.cmvn is not None:
            mat = apply_cmvn(mat, self.cmvn)
        return mat

    def forward(
        self,
        input: torch.Tensor,
        input_lengths,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute fbank (+ LFR + CMVN) features for a batch of waveforms.

        Args:
            input: batch of waveforms; row ``i`` is read as
                ``input[i][:input_lengths[i]]``.
            input_lengths: number of valid samples per row.
            **kwargs: accepted and ignored.

        Returns:
            ``(feats_pad, feats_lens)``. ``feats_pad`` is zero-padded to
            ``(batch, max_frames, dim)``; ``feats_lens`` holds the number of
            frames per utterance.
        """
        feats = []
        feats_lens = []
        for i in range(input.size(0)):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            if self.upsacle_samples:
                # Scale to the 16-bit integer range expected by Kaldi fbank.
                waveform = waveform * (1 << 15)
            mat = kaldi.fbank(
                waveform.unsqueeze(0),
                num_mel_bins=self.n_mels,
                # Never use a window longer than the clip itself (in ms).
                frame_length=min(self.frame_length, waveform_length / self.fs * 1000),
                frame_shift=self.frame_shift,
                dither=self.dither,
                energy_floor=0.0,
                window_type=self.window,
                sample_frequency=self.fs,
                snip_edges=self.snip_edges,
            )
            mat = self._apply_lfr_cmvn(mat)
            feats.append(mat)
            feats_lens.append(mat.size(0))

        feats_lens = torch.as_tensor(feats_lens)
        if len(feats) == 1:
            feats_pad = feats[0][None, :, :]
        else:
            feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
        return feats_pad, feats_lens

    def forward_fbank(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute raw fbank features only (no LFR, no CMVN).

        NOTE(review): unlike forward(), this always scales samples by 2**15,
        uses the configured frame_length unclamped, and does not pass
        snip_edges. The behavior is preserved as-is; confirm whether it
        should follow forward().
        """
        feats = []
        feats_lens = []
        for i in range(input.size(0)):
            waveform = input[i][: input_lengths[i]] * (1 << 15)
            mat = kaldi.fbank(
                waveform.unsqueeze(0),
                num_mel_bins=self.n_mels,
                frame_length=self.frame_length,
                frame_shift=self.frame_shift,
                dither=self.dither,
                energy_floor=0.0,
                window_type=self.window,
                sample_frequency=self.fs,
            )
            feats.append(mat)
            feats_lens.append(mat.size(0))

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
        return feats_pad, feats_lens

    def forward_lfr_cmvn(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply LFR + CMVN to precomputed, padded fbank features.

        Row ``i`` is read as ``input[i, :input_lengths[i], :]``. ``input`` is
        not modified.
        """
        feats = []
        feats_lens = []
        for i in range(input.size(0)):
            mat = self._apply_lfr_cmvn(input[i, : input_lengths[i], :])
            feats.append(mat)
            feats_lens.append(mat.size(0))

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
        return feats_pad, feats_lens


class FunASRFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a FunASR feature extractor.

    This feature extractor inherits from
    [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`], which
    contains most of the main methods. Users should refer to this superclass
    for more information regarding those methods.

    Raw speech is padded/truncated and then converted by [`WavFrontend`]. The
    frontend computes Kaldi-compatible log mel filter-bank features with
    `torchaudio.compliance.kaldi.fbank`, optionally followed by low-frame-rate
    stacking and CMVN, as configured by `frontend_conf`.

    Args:
        feature_size (`int`, *optional*, defaults to 80):
            The feature dimension of the extracted features.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio files should be digitized,
            expressed in hertz (Hz). Audio is not resampled.
        hop_length (`int`, *optional*, defaults to 160):
            Hop length in samples; used only to derive `nb_max_frames`.
        chunk_length (`int`, *optional*, defaults to 30):
            Length in seconds of the default padding/truncation target
            (`chunk_length * sampling_rate` samples).
        n_fft (`int`, *optional*, defaults to 400):
            Size of the Fourier transform. Stored for compatibility; the
            fbank computation is configured through `frontend_conf`.
        padding_value (`float`, *optional*, defaults to 0.0):
            Padding value used to pad the audio. Should correspond to silence.
        dither (`float`, *optional*, defaults to 0.0):
            Dithering constant forwarded to `WavFrontend` / `kaldi.fbank`.
            The value 0.0 disables dithering.
        return_attention_mask (`bool`, *optional*, defaults to `False`):
            Forwarded to the superclass.
        frontend_conf (`dict`, *optional*, passed through `**kwargs`):
            Keyword arguments for [`WavFrontend`] (e.g. `cmvn_file`, `lfr_m`,
            `lfr_n`, `n_mels`).
    """

    model_input_names = ["input_features"]

    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        hop_length=160,
        chunk_length=30,
        n_fft=400,
        padding_value=0.0,
        dither=0.0,
        return_attention_mask=False,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            return_attention_mask=return_attention_mask,
            **kwargs,
        )
        # An explicit `frontend_conf=None` is treated like "not provided".
        self.frontend_conf = kwargs.get("frontend_conf") or {}
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.chunk_length = chunk_length
        self.n_samples = chunk_length * sampling_rate
        self.nb_max_frames = self.n_samples // hop_length
        self.sampling_rate = sampling_rate
        self.dither = dither

    def extract_fbank(
        self, data, data_len=None, data_type: str = "sound", frontend=None, **kwargs
    ):
        """Run `frontend` over `data` and return ``(features, lengths)``.

        Args:
            data: a 1-D/2-D numpy array or tensor (2-D is ``[batch, N]``), or
                a list/tuple of 1-D arrays/tensors that are zero-padded into
                a batch.
            data_len: per-row valid lengths. It defaults to the full row
                length for array/tensor input and is always recomputed for
                list input.
            data_type: accepted for compatibility; unused.
            frontend: callable ``(data, data_len, **kwargs) -> (feats, lens)``,
                e.g. a `WavFrontend`.
            **kwargs: forwarded to `frontend`.

        Returns:
            ``(features as float32, lengths as int32)``.

        Raises:
            ValueError: if no `frontend` is given.
        """
        if frontend is None:
            raise ValueError("extract_fbank requires a `frontend` callable.")

        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
        if isinstance(data, torch.Tensor):
            if len(data.shape) < 2:
                data = data[None, :]  # data: [batch, N]
            if data_len is None:
                # One length per batch row; every row spans the full width.
                data_len = [data.shape[1]] * data.shape[0]
        elif isinstance(data, (list, tuple)):
            rows = [
                torch.from_numpy(item) if isinstance(item, np.ndarray) else item
                for item in data
            ]
            data_len = [row.shape[0] for row in rows]
            data = pad_sequence(rows, batch_first=True)

        data, data_len = frontend(data, data_len, **kwargs)
        if isinstance(data_len, (list, tuple)):
            data_len = torch.tensor([data_len])
        return data.to(torch.float32), data_len.to(torch.int32)

    def __call__(
        self,
        raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
        truncation: bool = True,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_attention_mask: bool | None = None,
        padding: str | None = "max_length",
        max_length: int | None = None,
        sampling_rate: int | None = None,
        do_normalize: bool | None = None,
        device: str | None = "cpu",
        return_token_timestamps: bool | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Featurize one or more raw waveforms.

        Returns a `BatchFeature` with:

        - `input_features`: the frontend output.
        - `speech_lengths`: feature frames per utterance.
        - `fake_token_lengths`: number of audio placeholder tokens per
          utterance, derived from `speech_lengths` by the downsampling
          formula below and clamped to at least 1.
        """
        if sampling_rate is not None and sampling_rate != self.sampling_rate:
            logger.warning(
                "FunASRFeatureExtractor expects audio sampled at %d Hz but "
                "received sampling_rate=%d; the audio is not resampled.",
                self.sampling_rate,
                sampling_rate,
            )

        is_batched = isinstance(raw_speech, (list, tuple)) and (
            isinstance(raw_speech[0], (np.ndarray, tuple, list))
        )

        if is_batched:
            raw_speech = [
                np.asarray([speech], dtype=np.float32).T for speech in raw_speech
            ]
        elif not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif raw_speech.dtype == np.float64:
            raw_speech = raw_speech.astype(np.float32)

        if not is_batched:
            raw_speech = [np.asarray([raw_speech]).T]

        batched_speech = BatchFeature({"input_features": raw_speech})

        padded_inputs = self.pad(
            batched_speech,
            padding=padding,
            max_length=max_length if max_length else self.n_samples,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask or do_normalize,
        )

        # Padded audio is [batch, samples, 1] -> [1, batch, samples].
        input_features = padded_inputs.get("input_features").transpose(2, 0, 1)

        # The extractor's own `dither` wins over any value in frontend_conf
        # (passing both as keywords would raise a TypeError).
        frontend = WavFrontend(**{**self.frontend_conf, "dither": self.dither})
        input_features, speech_lengths = self.extract_fbank(
            input_features[0],
            data_type=kwargs.get("data_type", "sound"),
            frontend=frontend,
            is_final=True,
        )
        # Output lengths after three stride-2 stages (k=3, p=1 for the
        # first two).
        olens = 1 + (speech_lengths - 3 + 2 * 1) // 2
        olens = 1 + (olens - 3 + 2 * 1) // 2
        fake_token_lengths = (olens - 1) // 2 + 1
        if isinstance(input_features[0], list):
            padded_inputs["input_features"] = [
                np.asarray(feature, dtype=np.float32) for feature in input_features
            ]
        else:
            padded_inputs["input_features"] = input_features

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        padded_inputs["speech_lengths"] = speech_lengths
        padded_inputs["fake_token_lengths"] = torch.clamp(fake_token_lengths, min=1)
        return padded_inputs


class FunASRProcessor(ProcessorMixin):
    r"""
    Constructs a FunASR processor which wraps a FunASR feature extractor and
    a FunASR tokenizer into a single processor.

    [`FunASRProcessor`] offers all the functionalities of
    [`FunASRFeatureExtractor`] and [`Qwen2Tokenizer`]. See the
    [`~FunASRProcessor.__call__`] and [`~FunASRProcessor.decode`] for more
    information.

    Args:
        feature_extractor (`FunASRFeatureExtractor`):
            An instance of [`FunASRFeatureExtractor`]. The feature extractor
            is a required input.
        tokenizer (`Qwen2Tokenizer`):
            An instance of [`Qwen2Tokenizer`]. The tokenizer is a required
            input.
    """

    feature_extractor_class = "FunASRFeatureExtractor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self,
        feature_extractor,
        tokenizer,
        audio_token="<|AUDIO|>",
    ):
        super().__init__(feature_extractor, tokenizer)
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False
        # Prefer the tokenizer's own audio token when it defines one.
        self.audio_token = (
            tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
        )
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)

    def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
        """Delegate to ``tokenizer.get_decoder_prompt_ids`` (the tokenizer
        must provide that method)."""
        return self.tokenizer.get_decoder_prompt_ids(
            task=task, language=language, no_timestamps=no_timestamps
        )

    def __call__(self, *args, **kwargs):
        """
        Forwards the `audio` argument to FunASRFeatureExtractor's
        [`~FunASRFeatureExtractor.__call__`] and the `text` argument to
        [`~Qwen2Tokenizer.__call__`]. Please refer to the docstring of the
        above two methods for more information.

        Each audio token in `text` is expanded to that audio's
        `fake_token_lengths` copies. Audios are consumed in the order their
        tokens appear across `text`. When audio is given, the returned
        features carry the tokenized ids under `labels`. Otherwise the
        tokenizer output is returned.

        Raises:
            ValueError: if `text` is missing or not a string / list of
                strings, or if the number of audio tokens differs from the
                number of audios.
        """
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        audio = kwargs.pop("audio", None)
        sampling_rate = kwargs.pop("sampling_rate", None)
        text = kwargs.pop("text", None)
        if len(args) > 0:
            audio = args[0]
            args = args[1:]

        if text is None:
            raise ValueError("You need to specify `text` input to process.")
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, (list, tuple)) or not all(
            isinstance(sample, str) for sample in text
        ):
            raise ValueError(
                "Invalid input text. Please provide a string, or a list of strings"
            )
        else:
            text = list(text)

        if audio is not None:
            # ensure we have as much audios as audio tokens
            num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
            num_audios = 1 if isinstance(audio, np.ndarray) else len(audio)
            if num_audio_tokens != num_audios:
                raise ValueError(
                    f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"  # noqa: E501
                )
            inputs = self.feature_extractor(
                audio, *args, sampling_rate=sampling_rate, **kwargs
            )

            # One expansion length per audio, consumed in token order.
            token_counts = iter(inputs["fake_token_lengths"].tolist())
            # Swap each audio token for a placeholder first; inserting the
            # expansion directly would keep `audio_token in sample` true
            # forever.
            placeholder = "<placeholder>"
            expanded_text = []
            for sample in text:
                replacements = []
                while self.audio_token in sample:
                    replacements.append(self.audio_token * next(token_counts))
                    sample = sample.replace(self.audio_token, placeholder, 1)
                for replacement in replacements:
                    sample = sample.replace(placeholder, replacement, 1)
                expanded_text.append(sample)
            text = expanded_text

        encodings = self.tokenizer(text, **kwargs)
        if audio is None:
            return encodings
        inputs["labels"] = encodings["input_ids"]
        return inputs

    def get_prompt_ids(self, text: str, return_tensors="np"):
        """Delegate to ``tokenizer.get_prompt_ids``."""
        return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)


AutoFeatureExtractor.register("FunASRFeatureExtractor", FunASRFeatureExtractor)