Commit c5b8e585 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add unit test for LibriMix dataset (#2659)

Summary:
Besides the unit test, the PR also addresses these issues:
- The original `LibriMix` dataset only supports "min" mode, which means the audio length is the minimum of all clean sources. This is the default for source separation tasks. Users may also want to use "max" mode, which allows for end-to-end separation and recognition. The PR adds a ``mode`` argument to let users decide which dataset variant they want to use.
- If the task is ``"enh_both"``, the target audios are those in ``mix_clean`` instead of the separate clean sources. The PR fixes it to use ``mix_clean`` as the target.

Pull Request resolved: https://github.com/pytorch/audio/pull/2659

Reviewed By: carolineechen

Differential Revision: D40229227

Pulled By: nateanl

fbshipit-source-id: fc07e0d88a245e1367656d3767cf98168a799235
parent be938e7e
import os
from parameterized import parameterized
from torchaudio.datasets import LibriMix
from torchaudio_unittest.common_utils import get_whitenoise, save_wav, TempDirMixin, TorchaudioTestCase
# Sample rate (Hz) used for every mocked audio file; matches the "wav8k" subdirectory.
_SAMPLE_RATE = 8000

# Maps each LibriMix task name to the mixture folder the dataset reads from.
# Note "enh_both" and "sep_noisy" share the same "mix_both" folder.
_TASKS_TO_MIXTURE = {
    "sep_clean": "mix_clean",
    "enh_single": "mix_single",
    "enh_both": "mix_both",
    "sep_noisy": "mix_both",
}
def _save_wav(filepath: str, seed: int):
    """Generate a short white-noise clip, write it to *filepath*, and return the tensor.

    The clip is mono, 0.01 s long, at ``_SAMPLE_RATE``; *seed* makes it reproducible.
    """
    waveform = get_whitenoise(sample_rate=_SAMPLE_RATE, duration=0.01, n_channels=1, seed=seed)
    save_wav(filepath, waveform, _SAMPLE_RATE)
    return waveform
def get_mock_dataset(root_dir: str, num_speaker: int):
    """Create a mocked Libri{N}Mix directory tree under *root_dir* and return the expected samples.

    root_dir: directory to the mocked dataset

    Returns a list (one entry per utterance) of dicts mapping each task name to a
    ``(sample_rate, mixture, sources)`` tuple mirroring what ``LibriMix`` should yield.
    """
    samples = []
    noise_seed = 0
    base_dir = os.path.join(root_dir, f"Libri{num_speaker}Mix", "wav8k", "min", "train-360")
    os.makedirs(base_dir, exist_ok=True)
    for utterance_id in range(10):
        filename = f"{utterance_id}.wav"
        outputs_per_task = {}
        for task, mixture_folder in _TASKS_TO_MIXTURE.items():
            # Create the mixture folder; its name depends on the task.
            mixture_dir = os.path.join(base_dir, mixture_folder)
            os.makedirs(mixture_dir, exist_ok=True)
            mixture = _save_wav(os.path.join(mixture_dir, filename), noise_seed)
            if task == "enh_both":
                # For "enh_both" the target is the clean mixture produced for "sep_clean".
                sources = [outputs_per_task["sep_clean"][1]]
            else:
                sources = []
                for speaker_id in range(num_speaker):
                    source_dir = os.path.join(base_dir, f"s{speaker_id + 1}")
                    os.makedirs(source_dir, exist_ok=True)
                    source_path = os.path.join(source_dir, filename)
                    if os.path.exists(source_path):
                        # Source files were already written while handling "sep_clean";
                        # reuse those tensors instead of regenerating them.
                        sources = outputs_per_task["sep_clean"][2]
                        break
                    sources.append(_save_wav(source_path, noise_seed))
                    # Advance the seed only when a new source file was written,
                    # matching the original generation order.
                    noise_seed += 1
            outputs_per_task[task] = (_SAMPLE_RATE, mixture, sources)
        samples.append(outputs_per_task)
    return samples
class TestLibriMix(TempDirMixin, TorchaudioTestCase):
    """Verify that ``LibriMix`` yields exactly the mocked samples for every task."""

    backend = "default"
    # Populated once per class in setUpClass.
    root_dir = None
    samples_2spk = []
    samples_3spk = []

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples_2spk = get_mock_dataset(cls.root_dir, 2)
        cls.samples_3spk = get_mock_dataset(cls.root_dir, 3)

    def _test_librimix(self, dataset, samples, task):
        """Assert that iterating *dataset* reproduces *samples* for *task*, in order and in full."""
        num_samples = 0
        for i, (sample_rate, mixture, sources) in enumerate(dataset):
            expected_rate, expected_mixture, expected_sources = samples[i][task]
            assert sample_rate == expected_rate
            self.assertEqual(mixture, expected_mixture)
            assert len(sources) == len(expected_sources)
            for source, expected in zip(sources, expected_sources):
                self.assertEqual(source, expected)
            num_samples += 1
        # Ensure the dataset did not drop or duplicate any utterance.
        assert num_samples == len(samples)

    # Fixed: ("sep_clean") without a trailing comma is a bare string, not a 1-tuple
    # (parentheses alone are a no-op). parameterized happens to accept strings, but
    # the entries are now consistent 1-tuples.
    @parameterized.expand([("sep_clean",), ("enh_single",), ("enh_both",), ("sep_noisy",)])
    def test_librimix_2speaker(self, task):
        dataset = LibriMix(self.root_dir, num_speakers=2, sample_rate=_SAMPLE_RATE, task=task)
        self._test_librimix(dataset, self.samples_2spk, task)

    @parameterized.expand([("sep_clean",), ("enh_single",), ("enh_both",), ("sep_noisy",)])
    def test_librimix_3speaker(self, task):
        dataset = LibriMix(self.root_dir, num_speakers=3, sample_rate=_SAMPLE_RATE, task=task)
        self._test_librimix(dataset, self.samples_3spk, task)
...@@ -7,6 +7,13 @@ from torch.utils.data import Dataset ...@@ -7,6 +7,13 @@ from torch.utils.data import Dataset
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]] SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
_TASKS_TO_MIXTURE = {
"sep_clean": "mix_clean",
"enh_single": "mix_single",
"enh_both": "mix_both",
"sep_noisy": "mix_both",
}
class LibriMix(Dataset): class LibriMix(Dataset):
r"""*LibriMix* :cite:`cosentino2020librimix` dataset. r"""*LibriMix* :cite:`cosentino2020librimix` dataset.
...@@ -19,12 +26,17 @@ class LibriMix(Dataset): ...@@ -19,12 +26,17 @@ class LibriMix(Dataset):
num_speakers (int, optional): The number of speakers, which determines the directories num_speakers (int, optional): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios. (Default: 2) N source audios. (Default: 2)
sample_rate (int, optional): sample rate of audio files. The ``sample_rate`` determines sample_rate (int, optional): Sample rate of audio files. The ``sample_rate`` determines
which subdirectory the audio are fetched. If any of the audio has a different sample which subdirectory the audio are fetched. If any of the audio has a different sample
rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000) rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
task (str, optional): the task of LibriMix. task (str, optional): The task of LibriMix.
Options: [``"enh_single"``, ``"enh_both"``, ``"sep_clean"``, ``"sep_noisy"``] Options: [``"enh_single"``, ``"enh_both"``, ``"sep_clean"``, ``"sep_noisy"``]
(Default: ``"sep_clean"``) (Default: ``"sep_clean"``)
mode (str, optional): The mode when creating the mixture. If set to ``"min"``, the lengths of mixture
and sources are the minimum length of all sources. If set to ``"max"``, the lengths of mixture and
sources are zero padded to the maximum length of all sources.
Options: [``"min"``, ``"max"``]
(Default: ``"min"``)
Note: Note:
The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
...@@ -37,20 +49,26 @@ class LibriMix(Dataset): ...@@ -37,20 +49,26 @@ class LibriMix(Dataset):
num_speakers: int = 2, num_speakers: int = 2,
sample_rate: int = 8000, sample_rate: int = 8000,
task: str = "sep_clean", task: str = "sep_clean",
mode: str = "min",
): ):
self.root = Path(root) / f"Libri{num_speakers}Mix" self.root = Path(root) / f"Libri{num_speakers}Mix"
if mode not in ["max", "min"]:
raise ValueError(f'Expect ``mode`` to be one in ["min", "max"]. Found {mode}.')
if sample_rate == 8000: if sample_rate == 8000:
self.root = self.root / "wav8k/min" / subset self.root = self.root / "wav8k" / mode / subset
elif sample_rate == 16000: elif sample_rate == 16000:
self.root = self.root / "wav16k/min" / subset self.root = self.root / "wav16k" / mode / subset
else: else:
raise ValueError(f"Unsupported sample rate. Found {sample_rate}.") raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.task = task self.task = task
self.mix_dir = (self.root / f"mix_{task.split('_')[1]}").resolve() self.mix_dir = (self.root / _TASKS_TO_MIXTURE[task]).resolve()
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)] if task == "enh_both":
self.src_dirs = [(self.root / "mix_clean")]
else:
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
self.files = [p.name for p in self.mix_dir.glob("*wav")] self.files = [p.name for p in self.mix_dir.glob("*.wav")]
self.files.sort() self.files.sort()
def _load_audio(self, path) -> torch.Tensor: def _load_audio(self, path) -> torch.Tensor:
...@@ -80,6 +98,7 @@ class LibriMix(Dataset): ...@@ -80,6 +98,7 @@ class LibriMix(Dataset):
Args: Args:
key (int): The index of the sample to be loaded key (int): The index of the sample to be loaded
Returns: Returns:
Tuple of the following items; Tuple of the following items;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment