Unverified commit c236a621 authored by Arthur, committed by GitHub

[CLAP] Add CLAP to the library (#21370)



* add model like clip

* update

* text model ok

* clap text works

* some refactor

- `CLAPVision` to `CLAPAudio`
- refactor kwargs of audio modules

* more refactor

* more refactor

* more refactor

* correct fusion

* more refactor

* new modules

* add basic processor

* fixup

* remove whisper copied from

* audio logits match

* add doc

* correct mel filters and add max length

* style

* few fixes

* forward passes

* fixup

* fixup

* some clean up

* remove mels from the dictionary

* pad after the repeat

* update padding when smaller

* fix padding

* style

* use swin patch merging

* use copied from swin

* processor with any tokenizer

* more copied from

* some clean up

* more refactor

* fix mel when rand_trunc

* style

* remove unused imports

* update processing

* remove image processing tests

* add testing file

* fix modeling issues

* replace with `is_longer`

* clap in serialization

* more refactor

* `make fixup`

* make fixup

* fix feature extractor

* update test feature extractor

* `make fixup`

* clean up config

* more clean up

* more cleanup

* update tests

* refactor tests and inits

* remove CLAP vision config

* remove CLAP from image processing auto and dummy vision objects

* update inits

* style

* reorder classes in modeling clap

* Use roberta tokenizer as the other weights are not open sourced

* small cleanup

* remove tokenization CLAP

* processor tokenizer is roberta

* update feature extraction doc

* remove clap from model zero shot

* update f_min and f_max to frequency_xx

* some changes

- fix modeling keys
- add `is_longer` in the forward pass
- make fixup

* make fixup

* consistent behavior between rand_crop and fusion

* add numpy resize and bilinear and documentation

* move resizing to image utils

* clean feature extraction

* import resize from correct file

* resize in image transforms

* update

* style

* style

* nit

* remove unused arguments from the feature extractor

* style

* few fixes + make fixup

* oops

* fix more tests

* add zero shot audio classification pipeline

* update zeroshot classification pipeline

* fixup

* fix copies

* all CI tests pass

* make fixup + fix docs

* fix docs

* fix docs

* update pipeline tests

* update zero shot pipeline

* update feature extraction clap

* update tokenization auto

* use nested simplify

* update pipeline tests

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* split in two lines

* fixes

* refactor

* clean up

* add integration tests

* update config docstring

* style

* update processor

* fix processor test

* fix feat extractor tests

* update docs

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix readmes

* fix tips

* Update src/transformers/models/auto/configuration_auto.py

* update doc and remove todo -> properly explained

* fix idx and typo

* typo

* cleanup config

* cleanup tests, styles and doc

* ignore docstyle on image transform

* add conversion script

* remove the `clap` index in favor of `CLAP`

* update __init__

* nits

* Update src/transformers/pipelines/__init__.py

* fix bug

* clarify config

* fix copy

* fix init

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix model output

* fix comment

* make fixup

* make fixup

* rename to `Clap`

* replace to `Clap`

* replace to `Clap`

* repo consistency

* again repo-consistency

* make fixup

* Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* add config

* changes

* update conversion

* Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove unused function

* update based on code reviews

* style

* more comments

* cleanup

* clean up

* style

* apply suggestions

* Empty commit

* pipeline will be added in a different PR

* update calls to audio utils functions

* update pipeline init

* style

* style

* styling again

* use pad

* fix repo-consistency

* update utils and add doc for audio utils

* clean up resize by using torch. update inits accordingly

* style

* Clap's tokenizer is RoBERTa

* add audio utils to internal toctree

* update toctree

* style

* update documentation and normalize naming across audio utils and feature extraction clap

* style

* clean up

* update doc and typos

* fix doctest

* update modeling code, got rid of a lot of reshaping

* style on added doc audio utils

* update modeling clap

* style

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* docstring variables with CLAP

* rename key

* update modeling CLAP

* update audio utils docstring

* update processing clap

* fix readmes

* fix toctree

* update configuration clap

* fix init

* make fixup

* fix

* fix

* update naming

* update

* update checkpoint path

* Apply suggestions from code review

* Major refactoring

* Update src/transformers/models/clap/configuration_clap.py

* merge

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 6b0257de
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_clap": [
"CLAP_PRETRAINED_MODEL_ARCHIVE_LIST",
"ClapAudioConfig",
"ClapConfig",
"ClapTextConfig",
],
"processing_clap": ["ClapProcessor"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_clap"] = [
"CLAP_PRETRAINED_MODEL_ARCHIVE_LIST",
"ClapModel",
"ClapPreTrainedModel",
"ClapTextModel",
"ClapTextModelWithProjection",
"ClapAudioModel",
"ClapAudioModelWithProjection",
]
_import_structure["feature_extraction_clap"] = ["ClapFeatureExtractor"]
if TYPE_CHECKING:
from .configuration_clap import (
CLAP_PRETRAINED_MODEL_ARCHIVE_LIST,
ClapAudioConfig,
ClapConfig,
ClapTextConfig,
)
from .processing_clap import ClapProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_clap import ClapFeatureExtractor
from .modeling_clap import (
CLAP_PRETRAINED_MODEL_ARCHIVE_LIST,
ClapAudioModel,
ClapAudioModelWithProjection,
ClapModel,
ClapPreTrainedModel,
ClapTextModel,
ClapTextModelWithProjection,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
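A short, hedged illustration of what the lazy init above exposes (a sketch, not part of the diff): the configuration and processor classes can be imported without torch, while the modeling classes and the feature extractor are registered behind the torch backend.

# Sketch only: these imports resolve through the lazy module defined above.
from transformers import ClapConfig, ClapProcessor        # importable without torch
from transformers import ClapModel, ClapFeatureExtractor  # require the torch backend

config = ClapConfig()  # builds default ClapTextConfig / ClapAudioConfig sub-configs
print(config.audio_config.enable_fusion)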
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re
import torch
from CLAP import create_model
from transformers import AutoFeatureExtractor, ClapConfig, ClapModel
KEYS_TO_MODIFY_MAPPING = {
"text_branch": "text_model",
"audio_branch": "audio_model.audio_encoder",
"attn": "attention.self",
"self.proj": "output.dense",
"attention.self_mask": "attn_mask",
"mlp.fc1": "intermediate.dense",
"mlp.fc2": "output.dense",
"norm1": "layernorm_before",
"norm2": "layernorm_after",
"bn0": "batch_norm",
}
processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")
def init_clap(checkpoint_path, enable_fusion=False):
model, model_cfg = create_model(
"HTSAT-tiny",
"roberta",
checkpoint_path,
precision="fp32",
device="cuda:0" if torch.cuda.is_available() else "cpu",
enable_fusion=enable_fusion,
fusion_type="aff_2d" if enable_fusion else None,
)
return model, model_cfg
def rename_state_dict(state_dict):
model_state_dict = {}
sequential_layers_pattern = r".*sequential.(\d+).*"
text_projection_pattern = r".*_projection.(\d+).*"
for key, value in state_dict.items():
# check if any key needs to be modified
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
if re.match(sequential_layers_pattern, key):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
# Because in CLAP they use `nn.Sequential`...
transformers_projection_layer = 1 if projecton_layer == 0 else 2
key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")
if "audio" and "qkv" in key:
# split qkv into query key and value
mixed_qkv = value
qkv_dim = mixed_qkv.size(0) // 3
query_layer = mixed_qkv[:qkv_dim]
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
value_layer = mixed_qkv[qkv_dim * 2 :]
model_state_dict[key.replace("qkv", "query")] = query_layer
model_state_dict[key.replace("qkv", "key")] = key_layer
model_state_dict[key.replace("qkv", "value")] = value_layer
else:
model_state_dict[key] = value
return model_state_dict
def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, enable_fusion=False):
clap_model, clap_model_cfg = init_clap(checkpoint_path, enable_fusion=enable_fusion)
clap_model.eval()
state_dict = clap_model.state_dict()
state_dict = rename_state_dict(state_dict)
transformers_config = ClapConfig()
transformers_config.audio_config.enable_fusion = enable_fusion
model = ClapModel(transformers_config)
# ignore the spectrogram embedding layer
model.load_state_dict(state_dict, strict=False)
model.save_pretrained(pytorch_dump_folder_path)
transformers_config.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
args = parser.parse_args()
convert_clap_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.enable_fusion)
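For reference, a hedged sketch of how the conversion entry point above might be invoked; the checkpoint path, output folder, and script filename are placeholders rather than values taken from this PR.

# Hypothetical invocation of the argparse interface defined above (paths are placeholders):
#   python convert_clap_original_pytorch_to_hf.py \
#       --checkpoint_path ./clap_htsat_tiny.pt \
#       --pytorch_dump_folder_path ./clap-converted \
#       --enable_fusion
# Equivalently, the function can be called directly:
convert_clap_checkpoint(
    checkpoint_path="./clap_htsat_tiny.pt",       # original CLAP checkpoint (placeholder)
    pytorch_dump_folder_path="./clap-converted",  # output folder (placeholder)
    config_path=None,                             # accepted but unused by the script above
    enable_fusion=False,
)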
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for CLAP."""
import copy
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from ...audio_utils import fram_wave, get_mel_filter_banks, power_to_db, stft
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
class ClapFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a CLAP feature extractor.
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.
This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the *Short Time
Fourier Transform* (STFT), which should match PyTorch's `torch.stft`.
Args:
feature_size (`int`, defaults to 64):
The feature dimension of the extracted Mel spectrograms. This corresponds to the number of mel filters
(`n_mels`).
sampling_rate (`int`, defaults to 48_000):
The sampling rate at which the audio files should be digitized, expressed in hertz (Hz). This only serves
to warn users if the audio fed to the feature extractor does not have the same sampling rate.
hop_length (`int`, defaults to 480):
Hop between the overlapping windows of the STFT used to obtain the mel spectrogram. The audio will be split
into smaller `frames` with a step of `hop_length` between each frame.
max_length_s (`int`, defaults to 10):
The maximum input length of the model in seconds. This is used to pad the audio.
fft_window_size (`int`, defaults to 1024):
Size of the window (in samples) on which the Fourier transform is applied. This controls the frequency
resolution of the spectrogram: 1024 means that the Fourier transform is computed on windows of 1024 samples.
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
return_attention_mask (`bool`, *optional*, defaults to `False`):
Whether or not the model should return the attention masks corresponding to the input.
frequency_min (`float`, *optional*, defaults to 0):
The lowest frequency of interest. The STFT will not be computed for values below this.
frequency_max (`float`, *optional*, defaults to 14_000):
The highest frequency of interest. The STFT will not be computed for values above this.
top_db (`float`, *optional*):
The highest decibel value used to convert the mel spectrogram to the log scale. For more details see the
`audio_utils.power_to_db` function.
truncation (`str`, *optional*, defaults to `"fusion"`):
Truncation pattern for long audio inputs. Two patterns are available:
- `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and a
downsampled version of the entire mel spectrogram.
If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a copy
of the original mel obtained from the padded audio.
- `rand_trunc` will select a random crop of the mel spectrogram.
padding (`str`, *optional*, defaults to `"repeatpad"`):
Padding pattern for shorter audio inputs. Three patterns were originally implemented:
- `repeatpad`: the audio is repeated, and then padded to fit the `max_length`.
- `repeat`: the audio is repeated and then cut to fit the `max_length`
- `pad`: the audio is padded.
"""
model_input_names = ["input_features", "is_longer"]
def __init__(
self,
feature_size=64,
sampling_rate=48_000,
hop_length=480,
max_length_s=10,
fft_window_size=1024,
padding_value=0.0,
return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
frequency_min: float = 0,
frequency_max: float = 14_000,
top_db: int = None,
truncation: str = "fusion",
padding: str = "repeatpad",
**kwargs,
):
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=return_attention_mask,
**kwargs,
)
self.top_db = top_db
self.truncation = truncation
self.padding = padding
self.fft_window_size = fft_window_size
self.nb_frequency_bins = (fft_window_size >> 1) + 1
self.hop_length = hop_length
self.max_length_s = max_length_s
self.nb_max_samples = max_length_s * sampling_rate
self.sampling_rate = sampling_rate
self.frequency_min = frequency_min
self.frequency_max = frequency_max
self.mel_filters = get_mel_filter_banks(
nb_frequency_bins=self.nb_frequency_bins,
nb_mel_filters=feature_size,
frequency_min=frequency_min,
frequency_max=frequency_max,
sample_rate=sampling_rate,
norm=None,
mel_scale="htk",
)
self.mel_filters_slaney = get_mel_filter_banks(
nb_frequency_bins=self.nb_frequency_bins,
nb_mel_filters=feature_size,
frequency_min=frequency_min,
frequency_max=frequency_max,
sample_rate=sampling_rate,
norm="slaney",
mel_scale="slaney",
)
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, except for the
mel filter banks, which do not need to be saved or printed as they are too long.
"""
output = copy.deepcopy(self.__dict__)
output["feature_extractor_type"] = self.__class__.__name__
if "mel_filters" in output:
del output["mel_filters"]
if "mel_filters_slaney" in output:
del output["mel_filters_slaney"]
return output
def _np_extract_fbank_features(self, waveform: np.array, mel_filters: Optional[np.array] = None) -> np.ndarray:
"""
Compute the log-Mel spectrogram of the provided `waveform` using the `hanning` window. In CLAP, two different
filter banks are used depending on the truncation pattern:
- `self.mel_filters`: they correspond to the default parameters of `torchaudio`, which can be obtained from
calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
is set to `"fusion"`.
- `self.mel_filters_slaney`: they correspond to the default parameters of `torchlibrosa`, which used
`librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
implementation when the truncation mode is not `"fusion"`.
"""
window = np.hanning(self.fft_window_size + 1)[:-1]
frames = fram_wave(waveform, self.hop_length, self.fft_window_size)
spectrogram = stft(frames, window, fft_window_size=self.fft_window_size)
magnitudes = np.abs(spectrogram) ** 2
mel_spectrogram = np.matmul(mel_filters.T, magnitudes)
log_mel_spectrogram = power_to_db(mel_spectrogram).T
log_mel_spectrogram = np.asarray(log_mel_spectrogram, np.float32)
return log_mel_spectrogram
def _random_mel_fusion(self, mel, total_frames, chunk_frames):
ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
if len(ranges[1]) == 0:
# if the audio is too short, we just use the first chunk
ranges[1] = [0]
if len(ranges[2]) == 0:
# if the audio is too short, we just use the first chunk
ranges[2] = [0]
# randomly choose index for each part
idx_front = np.random.choice(ranges[0])
idx_middle = np.random.choice(ranges[1])
idx_back = np.random.choice(ranges[2])
mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]
mel = torch.tensor(mel[None, None, :])
mel_shrink = torch.nn.functional.interpolate(
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False, antialias=False
)
mel_shrink = mel_shrink[0][0].numpy()
mel_fusion = np.stack([mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink], axis=0)
return mel_fusion
def _get_input_mel(self, waveform: np.array, max_length, truncation, padding) -> np.array:
"""
Extracts the mel spectrogram and prepares it for the model based on the `truncation` and `padding` arguments.
Four different paths are possible:
- `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram
will be computed on the entire audio. 3 random crops and a downsampled version of the full mel spectrogram
are then stacked together. They will later be used for `feature_fusion`.
- `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is
padded based on `padding`.
- `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded
based on `padding`, and is repeated `4` times.
- `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel
spectrogram will be computed on a random crop of the waveform.
"""
if waveform.shape[0] > max_length:
if truncation == "rand_trunc":
longer = True
# random crop to max_length (for compatibility) -> this should be handled by self.pad
overflow = len(waveform) - max_length
idx = np.random.randint(0, overflow + 1)
waveform = waveform[idx : idx + max_length]
input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :]
elif truncation == "fusion":
mel = self._np_extract_fbank_features(waveform, self.mel_filters)
chunk_frames = max_length // self.hop_length + 1 # the +1 related to how the spectrogram is computed
total_frames = mel.shape[0]
if chunk_frames == total_frames:
# there is a corner case where the audio length is larger than max_length but smaller than max_length+hop_length.
# In this case, we just use the whole audio.
input_mel = np.stack([mel, mel, mel, mel], axis=0)
longer = False
else:
input_mel = self._random_mel_fusion(mel, total_frames, chunk_frames)
longer = True
else:
raise NotImplementedError(f"data_truncating {truncation} not implemented")
else:
longer = False
# only use repeat as a new possible value for padding. you repeat the audio before applying the usual max_length padding
if waveform.shape[0] < max_length:
if padding == "repeat":
n_repeat = int(max_length / len(waveform))
waveform = np.stack(np.tile(waveform, n_repeat + 1))[:max_length]
if padding == "repeatpad":
n_repeat = int(max_length / len(waveform))
waveform = np.stack(np.tile(waveform, n_repeat))
waveform = np.pad(waveform, (0, max_length - waveform.shape[0]), mode="constant", constant_values=0)
if truncation == "fusion":
input_mel = self._np_extract_fbank_features(waveform, self.mel_filters)
input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0)
else:
input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :]
return input_mel, longer
def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
truncation: str = None,
padding: Optional[str] = None,
max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchFeature:
"""
Main method to featurize and prepare for the model one or several sequence(s).
Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values.
truncation (`str`, *optional*):
Truncation pattern for long audio inputs. Two patterns are available:
- `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and
a downsampled version of the entire mel spectrogram.
If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a
copy of the original mel obtained from the padded audio.
- `rand_trunc` will select a random crop of the mel spectrogram.
padding (`str`, *optional*):
Padding pattern for shorter audio inputs. Three patterns were originally implemented:
- `repeatpad`: the audio is repeated, and then padded to fit the `max_length`.
- `repeat`: the audio is repeated and then cut to fit the `max_length`
- `pad`: the audio is padded.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of list of python integers. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
sampling_rate (`int`, *optional*):
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
`sampling_rate` at the forward call to prevent silent errors.
"""
truncation = truncation if truncation is not None else self.truncation
padding = padding if padding else self.padding
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
f" was sampled with {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
)
if is_batched:
raw_speech = [np.asarray(speech, dtype=np.float64) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float64)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
raw_speech = raw_speech.astype(np.float64)
# always return batch
if not is_batched:
raw_speech = [np.asarray(raw_speech)]
# convert to mel spectrogram, truncate and pad if needed.
padded_inputs = [
self._get_input_mel(waveform, max_length if max_length else self.nb_max_samples, truncation, padding)
for waveform in raw_speech
]
input_mel = []
is_longer = []
for mel, longer in padded_inputs:
input_mel.append(mel)
is_longer.append(longer)
if truncation == "fusion" and sum(is_longer) == 0:
# if no audio is longer than 10s, then randomly select one audio to be longer
rand_idx = np.random.randint(0, len(input_mel))
is_longer[rand_idx] = True
if isinstance(input_mel[0], List):
input_mel = [np.asarray(feature, dtype=np.float64) for feature in input_mel]
input_features = {"input_features": input_mel, "is_longer": is_longer}
input_features = BatchFeature(input_features)
if return_tensors is not None:
input_features = input_features.convert_to_tensors(return_tensors)
return input_features
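A minimal, hedged usage sketch of the feature extractor defined above on synthetic audio; the shapes in the comments are indicative only and assume the default settings (48 kHz sampling rate, 10 s max length, 64 mel bins, truncation="fusion").

import numpy as np

feature_extractor = ClapFeatureExtractor()
waveform = np.random.randn(3 * 48_000)  # 3 seconds of placeholder mono audio at 48 kHz

inputs = feature_extractor(waveform, sampling_rate=48_000, return_tensors="np")
# Short audio is padded to 10 s and, with truncation="fusion", the mel spectrogram is stacked
# 4 times, giving input_features of shape (batch, 4, num_frames, 64) plus an is_longer flag.
print(inputs["input_features"].shape, inputs["is_longer"])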
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Audio/Text processor class for CLAP
"""
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
class ClapProcessor(ProcessorMixin):
r"""
Constructs a CLAP processor which wraps a CLAP feature extractor and a RoBERTa tokenizer into a single processor.
[`ClapProcessor`] offers all the functionalities of [`ClapFeatureExtractor`] and [`RobertaTokenizerFast`]. See the
[`~ClapProcessor.__call__`] and [`~ClapProcessor.decode`] for more information.
Args:
feature_extractor ([`ClapFeatureExtractor`]):
The audio processor is a required input.
tokenizer ([`RobertaTokenizerFast`]):
The tokenizer is a required input.
"""
feature_extractor_class = "ClapFeatureExtractor"
tokenizer_class = ("RobertaTokenizer", "RobertaTokenizerFast")
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
def __call__(self, text=None, audios=None, return_tensors=None, **kwargs):
"""
Main method to prepare for the model one or several sequence(s) and audio(s). This method forwards the `text`
and `kwargs` arguments to RobertaTokenizerFast's [`~RobertaTokenizerFast.__call__`] if `text` is not `None` to
encode the text. To prepare the audio(s), this method forwards the `audios` and `kwargs` arguments to
ClapFeatureExtractor's [`~ClapFeatureExtractor.__call__`] if `audios` is not `None`. Please refer to the
docstring of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The audio or batch of audios to be prepared. Each audio can be a NumPy array or a PyTorch tensor. In case
of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is the number of channels,
and T the sample length of the audio.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **audio_features** -- Audio features to be fed to a model. Returned when `audios` is not `None`.
"""
sampling_rate = kwargs.pop("sampling_rate", None)
if text is None and audios is None:
raise ValueError("You have to specify either text or audios. Both cannot be none.")
if text is not None:
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
if audios is not None:
audio_features = self.feature_extractor(
audios, sampling_rate=sampling_rate, return_tensors=return_tensors, **kwargs
)
if text is not None and audios is not None:
encoding["input_features"] = audio_features.input_features
return encoding
elif text is not None:
return encoding
else:
return BatchEncoding(data=dict(**audio_features), tensor_type=return_tensors)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to RobertaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to RobertaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer
to the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
feature_extractor_input_names = self.feature_extractor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
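A hedged end-to-end sketch pairing the processor above with `ClapModel` to score audio-text similarity. The checkpoint name is the one referenced elsewhere in this PR; the audio array is a stand-in, not real data.

import numpy as np
import torch
from transformers import ClapModel, ClapProcessor

model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")

audio = np.random.randn(5 * 48_000)  # 5 s of placeholder audio at 48 kHz
texts = ["Sound of a dog barking", "Sound of rain"]

inputs = processor(text=texts, audios=audio, sampling_rate=48_000, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)
# logits_per_audio holds the audio-text similarity scores (1 audio x 2 texts here)
probs = outputs.logits_per_audio.softmax(dim=-1)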
@@ -1464,6 +1464,58 @@ class ChineseCLIPVisionModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
CLAP_PRETRAINED_MODEL_ARCHIVE_LIST = None
class ClapAudioModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ClapAudioModelWithProjection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ClapFeatureExtractor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ClapModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ClapPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ClapTextModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ClapTextModelWithProjection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import random
import unittest
import numpy as np
from transformers import ClapFeatureExtractor
from transformers.testing_utils import require_torch, require_torchaudio
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_torch_available():
import torch
global_rng = random.Random()
# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for batch_idx in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
@require_torch
@require_torchaudio
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTester with Whisper->Clap
class ClapFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
min_seq_length=400,
max_seq_length=2000,
feature_size=10,
hop_length=160,
chunk_length=8,
padding_value=0.0,
sampling_rate=4_000,
return_attention_mask=False,
do_normalize=True,
):
self.parent = parent
self.batch_size = batch_size
self.min_seq_length = min_seq_length
self.max_seq_length = max_seq_length
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
self.padding_value = padding_value
self.sampling_rate = sampling_rate
self.return_attention_mask = return_attention_mask
self.do_normalize = do_normalize
self.feature_size = feature_size
self.chunk_length = chunk_length
self.hop_length = hop_length
def prepare_feat_extract_dict(self):
return {
"feature_size": self.feature_size,
"hop_length": self.hop_length,
"chunk_length": self.chunk_length,
"padding_value": self.padding_value,
"sampling_rate": self.sampling_rate,
"return_attention_mask": self.return_attention_mask,
"do_normalize": self.do_normalize,
}
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
def _flatten(list_of_lists):
return list(itertools.chain(*list_of_lists))
if equal_length:
speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
else:
# make sure that inputs increase in size
speech_inputs = [
floats_list((x, self.feature_size))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
]
if numpify:
speech_inputs = [np.asarray(x) for x in speech_inputs]
return speech_inputs
@require_torch
@require_torchaudio
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest with Whisper->Clap
class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = ClapFeatureExtractor
def setUp(self):
self.feat_extract_tester = ClapFeatureExtractionTester(self)
def test_call(self):
# Tests that all calls wrap to encode_plus and batch_encode_plus
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
# Test feature size
input_features = feature_extractor(np_speech_inputs, padding="max_length", return_tensors="np").input_features
self.assertTrue(input_features.ndim == 4)
# Test not batched input
encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="np").input_features
encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").input_features
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
# Test batched
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_double_precision_pad(self):
import torch
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
py_speech_inputs = np_speech_inputs.tolist()
for inputs in [py_speech_inputs, np_speech_inputs]:
np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
self.assertTrue(np_processed.input_features.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def integration_test_fusion(self):
# fmt: off
EXPECTED_INPUT_FEATURES = torch.tensor(
[
[
-30.2194, -22.4424, -18.6442, -17.2452, -22.7392, -32.2576, -36.1404,
-35.6120, -29.6229, -29.0454, -32.2157, -36.7664, -29.4436, -26.7825,
-31.1811, -38.3918, -38.8749, -43.4485, -47.6236, -38.7528, -31.8574,
-39.0591, -41.3190, -32.3319, -31.4699, -33.4502, -36.7412, -34.5265,
-35.1091, -40.4518, -42.7346, -44.5909, -44.9747, -45.8328, -47.0772,
-46.2723, -44.3613, -48.6253, -44.9551, -43.8700, -44.6104, -48.0146,
-42.7614, -47.3587, -47.4369, -45.5018, -47.0198, -42.8759, -47.5056,
-47.1567, -49.2621, -49.5643, -48.4330, -48.8495, -47.2512, -40.8439,
-48.1234, -49.1218, -48.7222, -50.2399, -46.8487, -41.9921, -50.4015,
-50.7827
],
[
-89.0141, -89.1411, -88.8096, -88.5480, -88.3481, -88.2038,
-88.1105, -88.0647, -88.0636, -88.1051, -88.1877, -88.1110,
-87.8613, -88.6679, -88.2685, -88.9684, -88.7977, -89.6264,
-89.9299, -90.3184, -91.1446, -91.9265, -92.7267, -93.6099,
-94.6395, -95.3243, -95.5923, -95.5773, -95.0889, -94.3354,
-93.5746, -92.9287, -92.4525, -91.9798, -91.8852, -91.7500,
-91.7259, -91.7561, -91.7959, -91.7070, -91.6914, -91.5019,
-91.0640, -90.0807, -88.7102, -87.0826, -85.5956, -84.4441,
-83.8461, -83.8605, -84.6702, -86.3900, -89.3073, -93.2926,
-96.3813, -97.3529, -100.0000, -99.6942, -92.2851, -87.9588,
-85.7214, -84.6807, -84.1940, -84.2021
],
[
-51.6882, -50.6852, -50.8198, -51.7428, -53.0325, -54.1619, -56.4903,
-59.0314, -60.7996, -60.5164, -59.9680, -60.5393, -62.5796, -65.4166,
-65.6149, -65.1409, -65.7226, -67.9057, -72.5089, -82.3530, -86.3189,
-83.4241, -79.1279, -79.3384, -82.7335, -79.8316, -80.2167, -74.3638,
-71.3930, -75.3849, -74.5381, -71.4504, -70.3791, -71.4547, -71.8820,
-67.3885, -69.5686, -71.9852, -71.0307, -73.0053, -80.8802, -72.9227,
-63.8526, -60.3260, -59.6012, -57.8316, -61.0603, -67.3403, -67.1709,
-60.4967, -60.5079, -68.3345, -67.5213, -70.6416, -79.6219, -78.2198,
-74.6851, -69.5718, -69.4968, -70.6882, -66.8175, -73.8558, -74.3855,
-72.9405
]
]
)
# fmt: on
MEL_BIN = [963, 963, 161]
input_speech = self._load_datasamples(1)
feature_extractor = ClapFeatureExtractor()
for padding, EXPECTED_VALUES, idx_in_mel in zip(
["repeat", "repeatpad", None], EXPECTED_INPUT_FEATURES, MEL_BIN
):
input_features = feature_extractor(input_speech, return_tensors="pt", padding=padding).input_features
self.assertTrue(torch.allclose(input_features[0, idx_in_mel], EXPECTED_VALUES, atol=1e-4))
def integration_test_rand_trunc(self):
# TODO in this case we should set the seed and use a longer audio to properly see the random truncation
# fmt: off
EXPECTED_INPUT_FEATURES = torch.tensor(
[
[
-42.3330, -36.2735, -35.9231, -43.5947, -48.4525, -46.5227, -42.6477,
-47.2740, -51.4336, -50.0846, -51.8711, -50.4232, -47.4736, -54.2275,
-53.3947, -55.4904, -54.8750, -54.5510, -55.4156, -57.4395, -51.7385,
-55.9118, -57.7800, -63.2064, -67.0651, -61.4379, -56.4268, -54.8667,
-52.3487, -56.4418, -57.1842, -55.1005, -55.6366, -59.4395, -56.8604,
-56.4949, -61.6573, -61.0826, -60.3250, -63.7876, -67.4882, -60.2323,
-54.6886, -50.5369, -47.7656, -45.8909, -49.1273, -57.4141, -58.3201,
-51.9862, -51.4897, -59.2561, -60.4730, -61.2203, -69.3174, -69.7464,
-65.5861, -58.9921, -59.5610, -61.0584, -58.1149, -64.4045, -66.2622,
-64.4610
],
[
-41.2298, -38.4211, -39.8834, -45.9950, -47.3839, -43.9849, -46.0371,
-52.5490, -56.6912, -51.8794, -50.1284, -49.7506, -53.9422, -63.2854,
-56.5754, -55.0469, -55.3181, -55.8115, -56.0058, -57.9215, -58.7597,
-59.1994, -59.2141, -64.4198, -73.5138, -64.4647, -59.3351, -54.5626,
-54.7508, -65.0230, -60.0270, -54.7644, -56.0108, -60.1531, -57.6879,
-56.3766, -63.3395, -65.3032, -61.5202, -63.0677, -68.4217, -60.6868,
-54.4619, -50.8533, -47.7200, -45.9197, -49.0961, -57.7621, -59.0750,
-51.9122, -51.4332, -59.4132, -60.3415, -61.6558, -70.7049, -69.7905,
-66.9104, -59.0324, -59.6138, -61.2023, -58.2169, -65.3837, -66.4425,
-64.4142
],
[
-51.6882, -50.6852, -50.8198, -51.7428, -53.0325, -54.1619, -56.4903,
-59.0314, -60.7996, -60.5164, -59.9680, -60.5393, -62.5796, -65.4166,
-65.6149, -65.1409, -65.7226, -67.9057, -72.5089, -82.3530, -86.3189,
-83.4241, -79.1279, -79.3384, -82.7335, -79.8316, -80.2167, -74.3638,
-71.3930, -75.3849, -74.5381, -71.4504, -70.3791, -71.4547, -71.8820,
-67.3885, -69.5686, -71.9852, -71.0307, -73.0053, -80.8802, -72.9227,
-63.8526, -60.3260, -59.6012, -57.8316, -61.0603, -67.3403, -67.1709,
-60.4967, -60.5079, -68.3345, -67.5213, -70.6416, -79.6219, -78.2198,
-74.6851, -69.5718, -69.4968, -70.6882, -66.8175, -73.8558, -74.3855,
-72.9405
]
]
)
# fmt: on
input_speech = self._load_datasamples(1)
feature_extractor = ClapFeatureExtractor()
for padding, EXPECTED_VALUES in zip(["repeat", "repeatpad", None], EXPECTED_INPUT_FEATURES):
input_features = feature_extractor(
input_speech, return_tensors="pt", truncation="rand_trunc", padding=padding
).input_features
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_VALUES, atol=1e-4))
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from transformers import ClapFeatureExtractor, ClapProcessor, RobertaTokenizer, RobertaTokenizerFast
from transformers.testing_utils import require_sentencepiece, require_torchaudio
from .test_feature_extraction_clap import floats_list
@require_torchaudio
@require_sentencepiece
class ClapProcessorTest(unittest.TestCase):
def setUp(self):
self.checkpoint = "laion/clap-htsat-unfused"
self.tmpdirname = tempfile.mkdtemp()
def get_tokenizer(self, **kwargs):
return RobertaTokenizer.from_pretrained(self.checkpoint, **kwargs)
def get_feature_extractor(self, **kwargs):
return ClapFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_default(self):
tokenizer = self.get_tokenizer()
feature_extractor = self.get_feature_extractor()
processor = ClapProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
processor.save_pretrained(self.tmpdirname)
processor = ClapProcessor.from_pretrained(self.tmpdirname)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertIsInstance(processor.tokenizer, RobertaTokenizerFast)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
self.assertIsInstance(processor.feature_extractor, ClapFeatureExtractor)
def test_save_load_pretrained_additional_features(self):
processor = ClapProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
processor = ClapProcessor.from_pretrained(
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.tokenizer, RobertaTokenizerFast)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, ClapFeatureExtractor)
def test_feature_extractor(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = ClapProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
raw_speech = floats_list((3, 1000))
input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
input_processor = processor(audios=raw_speech, return_tensors="np")
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
def test_tokenizer(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = ClapProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
input_str = "This is a test string"
encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = ClapProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
decoded_processor = processor.batch_decode(predicted_ids)
decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = ClapProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.assertListEqual(
processor.model_input_names[2:],
feature_extractor.model_input_names,
msg="`processor` and `feature_extractor` model input names do not match",
)
@@ -172,6 +172,10 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"ClapTextModel",
"ClapTextModelWithProjection",
"ClapAudioModel",
"ClapAudioModelWithProjection",
"Blip2ForConditionalGeneration", "Blip2ForConditionalGeneration",
"Blip2QFormerModel", "Blip2QFormerModel",
"Blip2VisionModel", "Blip2VisionModel",
......
@@ -40,6 +40,8 @@ src/transformers/models/bloom/configuration_bloom.py
src/transformers/models/camembert/configuration_camembert.py
src/transformers/models/canine/configuration_canine.py
src/transformers/models/canine/modeling_canine.py
src/transformers/models/clap/configuration_clap.py
src/transformers/models/clap/modeling_clap.py
src/transformers/models/clip/configuration_clip.py
src/transformers/models/clipseg/modeling_clipseg.py
src/transformers/models/codegen/configuration_codegen.py
...