"docs/source/vscode:/vscode.git/clone" did not exist on "6b241e0e3bda24546d15835e5e0b48b8d1e4732c"
Unverified Commit 450a181d authored by Susnato Dhar's avatar Susnato Dhar Committed by GitHub
Browse files

Add Pop2Piano (#21785)



* init commit

* config updated also some modeling

* Processor and Model config combined

* extraction pipeline(upto before spectogram & mel_conditioner) added but not properly tested

* model loading successful!

* feature extractor done!

* FE can now be called from HF

* postprocessing added in fe file

* same as prev commit

* Pop2PianoConfig doc done

* cfg docs slightly changed

* fe docs done

* batched

* batched working!

* temp

* v1

* checking

* trying to go with generate

* with generate and model tests passed

* before rebasing

* .

* tests done docs done remaining others & nits

* nits

* LogMelSpectogram shifted to FeatureExtractor

* is_tf rmeoved from pop2piano/init

* import solved

* tokenization tests added

* minor fixed regarding modeling_pop2piano

* tokenizer changed to only return midi_object and other changes

* Updated paper abstract(Camera-ready version) (#2)

* more comments and nits

* ruff changes

* code quality fix

* sg comments

* t5 change added and rebased

* comments except batching

* batching done

* comments

* small doc fix

* example removed from modeling

* ckpt

* forward it compatible with fe and generation done

* comments

* comments

* code-quality fix(maybe)

* ckpts changed

* doc file changed from mdx to md

* test fixes

* tokenizer test fix

* changes

* nits done main changes remaining

* code modified

* Pop2PianoProcessor added with tests

* other comments

* added Pop2PianoProcessor to dummy_objects

* added require_onnx to modeling file

* changes

* update .md file

* remove extra line in index.md

* back to the main index

* added pop2piano to index

* Added tokenizer.__call__ with valid args and batch_decode and aligned the processor part too

* changes

* added return types to 2 tokenizer methods

* the PR build test might work now

* added backends

* PR build fix

* vocab added

* comments

* refactored vocab into 1 file

* added conversion script

* comments

* essentia version changed in .md

* comments

* more tokenizer tests added

* minor fix

* tests extended for outputs acc check

* small fix

---------
Co-authored-by: default avatarJongho Choi <sweetcocoa@snu.ac.kr>
parent 6f041fcb
This diff is collapsed.
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Processor class for Pop2Piano."""
import os
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import BatchEncoding, PaddingStrategy, TruncationStrategy
from ...utils import TensorType
class Pop2PianoProcessor(ProcessorMixin):
r"""
Constructs an Pop2Piano processor which wraps a Pop2Piano Feature Extractor and Pop2Piano Tokenizer into a single
processor.
[`Pop2PianoProcessor`] offers all the functionalities of [`Pop2PianoFeatureExtractor`] and [`Pop2PianoTokenizer`].
See the docstring of [`~Pop2PianoProcessor.__call__`] and [`~Pop2PianoProcessor.decode`] for more information.
Args:
feature_extractor (`Pop2PianoFeatureExtractor`):
An instance of [`Pop2PianoFeatureExtractor`]. The feature extractor is a required input.
tokenizer (`Pop2PianoTokenizer`):
An instance of ['Pop2PianoTokenizer`]. The tokenizer is a required input.
"""
attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "Pop2PianoFeatureExtractor"
tokenizer_class = "Pop2PianoTokenizer"
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
def __call__(
self,
audio: Union[np.ndarray, List[float], List[np.ndarray]] = None,
sampling_rate: Union[int, List[int]] = None,
steps_per_beat: int = 2,
resample: Optional[bool] = True,
notes: Union[List, TensorType] = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
verbose: bool = True,
**kwargs,
) -> Union[BatchFeature, BatchEncoding]:
"""
This method uses [`Pop2PianoFeatureExtractor.__call__`] method to prepare log-mel-spectrograms for the model,
and [`Pop2PianoTokenizer.__call__`] to prepare token_ids from notes.
Please refer to the docstring of the above two methods for more information.
"""
# Since Feature Extractor needs both audio and sampling_rate and tokenizer needs both token_ids and
# feature_extractor_output, we must check for both.
if (audio is None and sampling_rate is None) and (notes is None):
raise ValueError(
"You have to specify at least audios and sampling_rate in order to use feature extractor or "
"notes to use the tokenizer part."
)
if audio is not None and sampling_rate is not None:
inputs = self.feature_extractor(
audio=audio,
sampling_rate=sampling_rate,
steps_per_beat=steps_per_beat,
resample=resample,
**kwargs,
)
if notes is not None:
encoded_token_ids = self.tokenizer(
notes=notes,
padding=padding,
truncation=truncation,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
verbose=verbose,
**kwargs,
)
if notes is None:
return inputs
elif audio is None or sampling_rate is None:
return encoded_token_ids
else:
inputs["token_ids"] = encoded_token_ids["token_ids"]
return inputs
def batch_decode(
self,
token_ids,
feature_extractor_output: BatchFeature,
return_midi: bool = True,
) -> BatchEncoding:
"""
This method uses [`Pop2PianoTokenizer.batch_decode`] method to convert model generated token_ids to midi_notes.
Please refer to the docstring of the above two methods for more information.
"""
return self.tokenizer.batch_decode(
token_ids=token_ids, feature_extractor_output=feature_extractor_output, return_midi=return_midi
)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
feature_extractor_input_names = self.feature_extractor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
def save_pretrained(self, save_directory, **kwargs):
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
return super().save_pretrained(save_directory, **kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(*args)
This diff is collapsed.
...@@ -32,6 +32,7 @@ is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse( ...@@ -32,6 +32,7 @@ is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse(
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11") is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11")
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11") is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
is_torch_1_8_0 = parsed_torch_version_base == version.parse("1.8.0")
def softmax_backward_data(parent, grad_output, output, dim, self): def softmax_backward_data(parent, grad_output, output, dim, self):
......
...@@ -57,6 +57,7 @@ from .utils import ( ...@@ -57,6 +57,7 @@ from .utils import (
is_cython_available, is_cython_available,
is_decord_available, is_decord_available,
is_detectron2_available, is_detectron2_available,
is_essentia_available,
is_faiss_available, is_faiss_available,
is_flax_available, is_flax_available,
is_ftfy_available, is_ftfy_available,
...@@ -71,6 +72,7 @@ from .utils import ( ...@@ -71,6 +72,7 @@ from .utils import (
is_pandas_available, is_pandas_available,
is_peft_available, is_peft_available,
is_phonemizer_available, is_phonemizer_available,
is_pretty_midi_available,
is_pyctcdecode_available, is_pyctcdecode_available,
is_pytesseract_available, is_pytesseract_available,
is_pytest_available, is_pytest_available,
...@@ -825,6 +827,20 @@ def require_librosa(test_case): ...@@ -825,6 +827,20 @@ def require_librosa(test_case):
return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case) return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
def require_essentia(test_case):
"""
Decorator marking a test that requires essentia
"""
return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case)
def require_pretty_midi(test_case):
"""
Decorator marking a test that requires pretty_midi
"""
return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case)
def cmd_exists(cmd): def cmd_exists(cmd):
return shutil.which(cmd) is not None return shutil.which(cmd) is not None
......
...@@ -112,6 +112,7 @@ from .import_utils import ( ...@@ -112,6 +112,7 @@ from .import_utils import (
is_datasets_available, is_datasets_available,
is_decord_available, is_decord_available,
is_detectron2_available, is_detectron2_available,
is_essentia_available,
is_faiss_available, is_faiss_available,
is_flax_available, is_flax_available,
is_ftfy_available, is_ftfy_available,
...@@ -130,6 +131,7 @@ from .import_utils import ( ...@@ -130,6 +131,7 @@ from .import_utils import (
is_pandas_available, is_pandas_available,
is_peft_available, is_peft_available,
is_phonemizer_available, is_phonemizer_available,
is_pretty_midi_available,
is_protobuf_available, is_protobuf_available,
is_psutil_available, is_psutil_available,
is_py3nvml_available, is_py3nvml_available,
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class Pop2PianoFeatureExtractor(metaclass=DummyObject):
_backends = ["essentia", "librosa", "pretty_midi", "scipy", "torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["essentia", "librosa", "pretty_midi", "scipy", "torch"])
class Pop2PianoTokenizer(metaclass=DummyObject):
_backends = ["essentia", "librosa", "pretty_midi", "scipy", "torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["essentia", "librosa", "pretty_midi", "scipy", "torch"])
class Pop2PianoProcessor(metaclass=DummyObject):
_backends = ["essentia", "librosa", "pretty_midi", "scipy", "torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["essentia", "librosa", "pretty_midi", "scipy", "torch"])
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class Pop2PianoFeatureExtractor(metaclass=DummyObject):
_backends = ["music"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["music"])
class Pop2PianoTokenizer(metaclass=DummyObject):
_backends = ["music"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["music"])
...@@ -5935,6 +5935,23 @@ class PoolFormerPreTrainedModel(metaclass=DummyObject): ...@@ -5935,6 +5935,23 @@ class PoolFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
POP2PIANO_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Pop2PianoForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Pop2PianoPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
...@@ -185,6 +185,22 @@ else: ...@@ -185,6 +185,22 @@ else:
logger.info("Disabling Tensorflow because USE_TORCH is set") logger.info("Disabling Tensorflow because USE_TORCH is set")
_essentia_available = importlib.util.find_spec("essentia") is not None
try:
_essentia_version = importlib.metadata.version("essentia")
logger.debug(f"Successfully imported essentia version {_essentia_version}")
except importlib.metadata.PackageNotFoundError:
_essentia_version = False
_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
try:
_pretty_midi_version = importlib.metadata.version("pretty_midi")
logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
except importlib.metadata.PackageNotFoundError:
_pretty_midi_available = False
ccl_version = "N/A" ccl_version = "N/A"
_is_ccl_available = ( _is_ccl_available = (
importlib.util.find_spec("torch_ccl") is not None importlib.util.find_spec("torch_ccl") is not None
...@@ -242,6 +258,14 @@ def is_librosa_available(): ...@@ -242,6 +258,14 @@ def is_librosa_available():
return _librosa_available return _librosa_available
def is_essentia_available():
return _essentia_available
def is_pretty_midi_available():
return _pretty_midi_available
def is_torch_cuda_available(): def is_torch_cuda_available():
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -986,6 +1010,27 @@ CCL_IMPORT_ERROR = """ ...@@ -986,6 +1010,27 @@ CCL_IMPORT_ERROR = """
Please note that you may need to restart your runtime after installation. Please note that you may need to restart your runtime after installation.
""" """
# docstyle-ignore
ESSENTIA_IMPORT_ERROR = """
{0} requires essentia library. But that was not found in your environment. You can install them with pip:
`pip install essentia==2.1b6.dev1034`
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
LIBROSA_IMPORT_ERROR = """
{0} requires thes librosa library. But that was not found in your environment. You can install them with pip:
`pip install librosa`
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PRETTY_MIDI_IMPORT_ERROR = """
{0} requires thes pretty_midi library. But that was not found in your environment. You can install them with pip:
`pip install pretty_midi`
Please note that you may need to restart your runtime after installation.
"""
DECORD_IMPORT_ERROR = """ DECORD_IMPORT_ERROR = """
{0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install {0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install
decord`. Please note that you may need to restart your runtime after installation. decord`. Please note that you may need to restart your runtime after installation.
...@@ -1011,11 +1056,14 @@ BACKENDS_MAPPING = OrderedDict( ...@@ -1011,11 +1056,14 @@ BACKENDS_MAPPING = OrderedDict(
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)), ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)), ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)), ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)), ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)), ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)), ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)), ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
from datasets import load_dataset
from transformers.testing_utils import (
check_json_file_has_correct_format,
require_essentia,
require_librosa,
require_scipy,
require_tf,
require_torch,
)
from transformers.utils.import_utils import (
is_essentia_available,
is_librosa_available,
is_scipy_available,
is_torch_available,
)
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
requirements_available = (
is_torch_available() and is_essentia_available() and is_scipy_available() and is_librosa_available()
)
if requirements_available:
import torch
from transformers import Pop2PianoFeatureExtractor
class Pop2PianoFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
n_bars=2,
sample_rate=22050,
use_mel=True,
padding_value=0,
vocab_size_special=4,
vocab_size_note=128,
vocab_size_velocity=2,
vocab_size_time=100,
):
self.parent = parent
self.n_bars = n_bars
self.sample_rate = sample_rate
self.use_mel = use_mel
self.padding_value = padding_value
self.vocab_size_special = vocab_size_special
self.vocab_size_note = vocab_size_note
self.vocab_size_velocity = vocab_size_velocity
self.vocab_size_time = vocab_size_time
def prepare_feat_extract_dict(self):
return {
"n_bars": self.n_bars,
"sample_rate": self.sample_rate,
"use_mel": self.use_mel,
"padding_value": self.padding_value,
"vocab_size_special": self.vocab_size_special,
"vocab_size_note": self.vocab_size_note,
"vocab_size_velocity": self.vocab_size_velocity,
"vocab_size_time": self.vocab_size_time,
}
@require_torch
@require_essentia
@require_librosa
@require_scipy
class Pop2PianoFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = Pop2PianoFeatureExtractor if requirements_available else None
def setUp(self):
self.feat_extract_tester = Pop2PianoFeatureExtractionTester(self)
def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = feat_extract_first.use_mel
mel_2 = feat_extract_second.use_mel
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_feat_extract_to_json_file(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
feat_extract_first.to_json_file(json_file_path)
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = feat_extract_first.use_mel
mel_2 = feat_extract_second.use_mel
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_call(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input = np.zeros([1000000], dtype=np.float32)
input_features = feature_extractor(speech_input, sampling_rate=16_000, return_tensors="np")
self.assertTrue(input_features.input_features.ndim == 3)
self.assertEqual(input_features.input_features.shape[-1], 512)
self.assertTrue(input_features.beatsteps.ndim == 2)
self.assertTrue(input_features.extrapolated_beatstep.ndim == 2)
def test_integration(self):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
speech_samples = ds.sort("id").select([0])["audio"]
input_speech = [x["array"] for x in speech_samples][0]
sampling_rate = [x["sampling_rate"] for x in speech_samples][0]
feaure_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
input_features = feaure_extractor(
input_speech, sampling_rate=sampling_rate, return_tensors="pt"
).input_features
EXPECTED_INPUT_FEATURES = torch.tensor(
[[-7.1493, -6.8701, -4.3214], [-5.9473, -5.7548, -3.8438], [-6.1324, -5.9018, -4.3778]]
)
self.assertTrue(torch.allclose(input_features[0, :3, :3], EXPECTED_INPUT_FEATURES, atol=1e-4))
def test_attention_mask(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2],
sampling_rate=[44_100, 16_000],
return_tensors="np",
return_attention_mask=True,
)
self.assertTrue(hasattr(input_features, "attention_mask"))
# check shapes
self.assertTrue(input_features["attention_mask"].ndim == 2)
self.assertEqual(input_features["attention_mask_beatsteps"].shape[0], 2)
self.assertEqual(input_features["attention_mask_extrapolated_beatstep"].shape[0], 2)
# check if they are any values except 0 and 1
self.assertTrue(np.max(input_features["attention_mask"]) == 1)
self.assertTrue(np.max(input_features["attention_mask_beatsteps"]) == 1)
self.assertTrue(np.max(input_features["attention_mask_extrapolated_beatstep"]) == 1)
self.assertTrue(np.min(input_features["attention_mask"]) == 0)
self.assertTrue(np.min(input_features["attention_mask_beatsteps"]) == 0)
self.assertTrue(np.min(input_features["attention_mask_extrapolated_beatstep"]) == 0)
def test_batch_feature(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_attention_mask=True,
)
self.assertEqual(len(input_features["input_features"].shape), 3)
# check shape
self.assertEqual(input_features["beatsteps"].shape[0], 3)
self.assertEqual(input_features["extrapolated_beatstep"].shape[0], 3)
def test_batch_feature_np(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_tensors="np",
return_attention_mask=True,
)
# check np array or not
self.assertEqual(type(input_features["input_features"]), np.ndarray)
# check shape
self.assertEqual(len(input_features["input_features"].shape), 3)
def test_batch_feature_pt(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_tensors="pt",
return_attention_mask=True,
)
# check pt tensor or not
self.assertEqual(type(input_features["input_features"]), torch.Tensor)
# check shape
self.assertEqual(len(input_features["input_features"].shape), 3)
@require_tf
def test_batch_feature_tf(self):
import tensorflow as tf
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_tensors="tf",
return_attention_mask=True,
)
# check tf tensor or not
self.assertTrue(tf.is_tensor(input_features["input_features"]))
# check shape
self.assertEqual(len(input_features["input_features"].shape), 3)
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_accepts_tensors_pt(self):
pass
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_accepts_tensors_tf(self):
pass
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_from_list(self):
pass
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_from_array(self):
pass
@unittest.skip("Pop2PianoFeatureExtractor does not support truncation")
def test_attention_mask_with_truncation(self):
pass
@unittest.skip("Pop2PianoFeatureExtractor does not supports truncation")
def test_truncation_from_array(self):
pass
@unittest.skip("Pop2PianoFeatureExtractor does not supports truncation")
def test_truncation_from_list(self):
pass
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -58,6 +58,8 @@ SPECIAL_CASES_TO_ALLOW = { ...@@ -58,6 +58,8 @@ SPECIAL_CASES_TO_ALLOW = {
# used internally in the configuration class file # used internally in the configuration class file
"LongT5Config": ["feed_forward_proj"], "LongT5Config": ["feed_forward_proj"],
# used internally in the configuration class file # used internally in the configuration class file
"Pop2PianoConfig": ["feed_forward_proj"],
# used internally in the configuration class file
"SwitchTransformersConfig": ["feed_forward_proj"], "SwitchTransformersConfig": ["feed_forward_proj"],
# having default values other than `1e-5` - we can't fix them without breaking # having default values other than `1e-5` - we can't fix them without breaking
"BioGptConfig": ["layer_norm_eps"], "BioGptConfig": ["layer_norm_eps"],
......
...@@ -66,6 +66,7 @@ PRIVATE_MODELS = [ ...@@ -66,6 +66,7 @@ PRIVATE_MODELS = [
"T5Stack", "T5Stack",
"MT5Stack", "MT5Stack",
"UMT5Stack", "UMT5Stack",
"Pop2PianoStack",
"SwitchTransformersStack", "SwitchTransformersStack",
"TFDPRSpanPredictor", "TFDPRSpanPredictor",
"MaskFormerSwinModel", "MaskFormerSwinModel",
......
...@@ -346,6 +346,8 @@ src/transformers/models/poolformer/configuration_poolformer.py ...@@ -346,6 +346,8 @@ src/transformers/models/poolformer/configuration_poolformer.py
src/transformers/models/poolformer/feature_extraction_poolformer.py src/transformers/models/poolformer/feature_extraction_poolformer.py
src/transformers/models/poolformer/image_processing_poolformer.py src/transformers/models/poolformer/image_processing_poolformer.py
src/transformers/models/poolformer/modeling_poolformer.py src/transformers/models/poolformer/modeling_poolformer.py
src/transformers/models/pop2piano/configuration_pop2piano.py
src/transformers/models/pop2piano/modeling_pop2piano.py
src/transformers/models/prophetnet/tokenization_prophetnet.py src/transformers/models/prophetnet/tokenization_prophetnet.py
src/transformers/models/rag/tokenization_rag.py src/transformers/models/rag/tokenization_rag.py
src/transformers/models/realm/configuration_realm.py src/transformers/models/realm/configuration_realm.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment