Commit dc0990c7 authored by nateanl's avatar nateanl Committed by Zhaoheng Ni
Browse files

[Cherry-picked 0.10] Move LibriMix dataset to datasets directory (#1833)

parent e6fccfda
from argparse import ArgumentParser from argparse import ArgumentParser
import pathlib from pathlib import Path
from lightning_train import _get_model, _get_dataloader, sisdri_metric from lightning_train import _get_model, _get_dataloader, sisdri_metric
import mir_eval import mir_eval
import torch import torch
def eval(model, data_loader, device): def _eval(model, data_loader, device):
results = torch.zeros(4) results = torch.zeros(4)
with torch.no_grad(): with torch.no_grad():
for i, batch in enumerate(data_loader): for _, batch in enumerate(data_loader):
mix, src, mask = batch mix, src, mask = batch
mix, src, mask = mix.to(device), src.to(device), mask.to(device) mix, src, mask = mix.to(device), src.to(device), mask.to(device)
est = model(mix) est = model(mix)
...@@ -35,7 +35,11 @@ def eval(model, data_loader, device): ...@@ -35,7 +35,11 @@ def eval(model, data_loader, device):
def cli_main(): def cli_main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"]) parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"])
parser.add_argument("--data-dir", default=pathlib.Path("./Libri2Mix/wav8k/min"), type=pathlib.Path) parser.add_argument(
"--root-dir",
type=Path,
help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
)
parser.add_argument( parser.add_argument(
"--librimix-tr-split", "--librimix-tr-split",
default="train-360", default="train-360",
...@@ -60,8 +64,8 @@ def cli_main(): ...@@ -60,8 +64,8 @@ def cli_main():
) )
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
default=pathlib.Path("./exp"), default=Path("./exp"),
type=pathlib.Path, type=Path,
help="The directory to save checkpoints and logs." help="The directory to save checkpoints and logs."
) )
parser.add_argument( parser.add_argument(
...@@ -95,7 +99,7 @@ def cli_main(): ...@@ -95,7 +99,7 @@ def cli_main():
args.librimix_tr_split, args.librimix_tr_split,
) )
eval(model, eval_loader, device) _eval(model, eval_loader, device)
if __name__ == "__main__": if __name__ == "__main__":
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# pyre-strict # pyre-strict
from pathlib import Path
import pathlib
from argparse import ArgumentParser from argparse import ArgumentParser
from typing import ( from typing import (
Any, Any,
...@@ -13,6 +12,7 @@ from typing import ( ...@@ -13,6 +12,7 @@ from typing import (
Optional, Optional,
Tuple, Tuple,
TypedDict, TypedDict,
Union,
) )
import torch import torch
...@@ -279,7 +279,7 @@ def _get_model( ...@@ -279,7 +279,7 @@ def _get_model(
def _get_dataloader( def _get_dataloader(
dataset_type: str, dataset_type: str,
dataset_dir: pathlib.Path, root_dir: Union[str, Path],
num_speakers: int = 2, num_speakers: int = 2,
sample_rate: int = 8000, sample_rate: int = 8000,
batch_size: int = 6, batch_size: int = 6,
...@@ -291,11 +291,11 @@ def _get_dataloader( ...@@ -291,11 +291,11 @@ def _get_dataloader(
Args: Args:
dataset_type (str): the dataset to use. dataset_type (str): the dataset to use.
dataset_dir (pathlib.Path): the root directory of the dataset. root_dir (str or Path): the root directory of the dataset.
num_speakers (int): the number of speakers in the mixture. (Default: 2) num_speakers (int, optional): the number of speakers in the mixture. (Default: 2)
sample_rate (int): the sample rate of the audio. (Default: 8000) sample_rate (int, optional): the sample rate of the audio. (Default: 8000)
batch_size (int): the batch size of the dataset. (Default: 6) batch_size (int, optional): the batch size of the dataset. (Default: 6)
num_workers (int): the number of workers for each dataloader. (Default: 4) num_workers (int, optional): the number of workers for each dataloader. (Default: 4)
librimix_task (str or None, optional): the task in LibriMix dataset. librimix_task (str or None, optional): the task in LibriMix dataset.
librimix_tr_split (str or None, optional): the training split in LibriMix dataset. librimix_tr_split (str or None, optional): the training split in LibriMix dataset.
...@@ -303,7 +303,7 @@ def _get_dataloader( ...@@ -303,7 +303,7 @@ def _get_dataloader(
tuple: (train_loader, valid_loader, eval_loader) tuple: (train_loader, valid_loader, eval_loader)
""" """
train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset( train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
dataset_type, dataset_dir, num_speakers, sample_rate, librimix_task, librimix_tr_split dataset_type, root_dir, num_speakers, sample_rate, librimix_task, librimix_tr_split
) )
train_collate_fn = dataset_utils.get_collate_fn( train_collate_fn = dataset_utils.get_collate_fn(
dataset_type, mode='train', sample_rate=sample_rate, duration=3 dataset_type, mode='train', sample_rate=sample_rate, duration=3
...@@ -337,9 +337,13 @@ def _get_dataloader( ...@@ -337,9 +337,13 @@ def _get_dataloader(
def cli_main(): def cli_main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--batch-size", default=3, type=int) parser.add_argument("--batch-size", default=6, type=int)
parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"]) parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"])
parser.add_argument("--data-dir", default=pathlib.Path("./Libri2Mix/wav8k/min"), type=pathlib.Path) parser.add_argument(
"--root-dir",
type=Path,
help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
)
parser.add_argument( parser.add_argument(
"--librimix-tr-split", "--librimix-tr-split",
default="train-360", default="train-360",
...@@ -364,8 +368,8 @@ def cli_main(): ...@@ -364,8 +368,8 @@ def cli_main():
) )
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
default=pathlib.Path("./exp"), default=Path("./exp"),
type=pathlib.Path, type=Path,
help="The directory to save checkpoints and logs." help="The directory to save checkpoints and logs."
) )
parser.add_argument( parser.add_argument(
...@@ -404,7 +408,7 @@ def cli_main(): ...@@ -404,7 +408,7 @@ def cli_main():
) )
train_loader, valid_loader, eval_loader = _get_dataloader( train_loader, valid_loader, eval_loader = _get_dataloader(
args.dataset, args.dataset,
args.data_dir, args.root_dir,
args.num_speakers, args.num_speakers,
args.sample_rate, args.sample_rate,
args.batch_size, args.batch_size,
......
from . import utils, wsj0mix, librimix from . import utils, wsj0mix
__all__ = ['utils', 'wsj0mix', 'librimix'] __all__ = ['utils', 'wsj0mix']
...@@ -2,9 +2,10 @@ from typing import List ...@@ -2,9 +2,10 @@ from typing import List
from functools import partial from functools import partial
from collections import namedtuple from collections import namedtuple
from torchaudio.datasets import LibriMix
import torch import torch
from . import wsj0mix, librimix from . import wsj0mix
Batch = namedtuple("Batch", ["mix", "src", "mask"]) Batch = namedtuple("Batch", ["mix", "src", "mask"])
...@@ -15,9 +16,9 @@ def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, li ...@@ -15,9 +16,9 @@ def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, li
validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate) validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate)
evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate) evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate)
elif dataset_type == "librimix": elif dataset_type == "librimix":
train = librimix.LibriMix(root_dir / librimix_tr_split, num_speakers, sample_rate, task) train = LibriMix(root_dir, librimix_tr_split, num_speakers, sample_rate, task)
validation = librimix.LibriMix(root_dir / "dev", num_speakers, sample_rate, task) validation = LibriMix(root_dir, "dev", num_speakers, sample_rate, task)
evaluation = librimix.LibriMix(root_dir / "test", num_speakers, sample_rate, task) evaluation = LibriMix(root_dir, "test", num_speakers, sample_rate, task)
else: else:
raise ValueError(f"Unexpected dataset: {dataset_type}") raise ValueError(f"Unexpected dataset: {dataset_type}")
return train, validation, evaluation return train, validation, evaluation
......
...@@ -8,6 +8,7 @@ from .yesno import YESNO ...@@ -8,6 +8,7 @@ from .yesno import YESNO
from .ljspeech import LJSPEECH from .ljspeech import LJSPEECH
from .cmuarctic import CMUARCTIC from .cmuarctic import CMUARCTIC
from .cmudict import CMUDict from .cmudict import CMUDict
from .librimix import LibriMix
from .libritts import LIBRITTS from .libritts import LIBRITTS
from .tedlium import TEDLIUM from .tedlium import TEDLIUM
...@@ -23,6 +24,7 @@ __all__ = [ ...@@ -23,6 +24,7 @@ __all__ = [
"GTZAN", "GTZAN",
"CMUARCTIC", "CMUARCTIC",
"CMUDict", "CMUDict",
"LibriMix",
"LIBRITTS", "LIBRITTS",
"diskcache_iterator", "diskcache_iterator",
"bg_iterator", "bg_iterator",
......
...@@ -13,27 +13,43 @@ class LibriMix(Dataset): ...@@ -13,27 +13,43 @@ class LibriMix(Dataset):
r"""Create the LibriMix dataset. r"""Create the LibriMix dataset.
Args: Args:
root (str or Path): the path to the directory where the dataset is stored. root (str or Path): The path to the directory where the directory ``Libri2Mix`` or
``Libri3Mix`` is stored.
subset (str, optional): The subset to use. Options: [``train-360`, ``train-100``,
``dev``, and ``test``] (Default: ``train-360``).
num_speakers (int, optional): The number of speakers, which determines the directories num_speakers (int, optional): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios. (Default: 2) N source audios. (Default: 2)
sample_rate (int, optional): sample rate of audio files. If any of the audio has a sample_rate (int, optional): sample rate of audio files. The ``sample_rate`` determines
different sample rate, raises ``ValueError``. (Default: 8000) which subdirectory the audio are fetched. If any of the audio has a different sample
rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
task (str, optional): the task of LibriMix. task (str, optional): the task of LibriMix.
Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``] Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``]
(Default: ``sep_clean``) (Default: ``sep_clean``)
Note:
The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
""" """
def __init__( def __init__(
self, self,
root: Union[str, Path], root: Union[str, Path],
subset: str = "train-360",
num_speakers: int = 2, num_speakers: int = 2,
sample_rate: int = 8000, sample_rate: int = 8000,
task: str = "sep_clean", task: str = "sep_clean",
): ):
self.root = Path(root) self.root = Path(root) / f"Libri{num_speakers}Mix"
if sample_rate == 8000:
self.root = self.root / "wav8k/min" / subset
elif sample_rate == 16000:
self.root = self.root / "wav16k/min" / subset
else:
raise ValueError(
f"Unsupported sample rate. Found {sample_rate}."
)
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.task = task self.task = task
self.mix_dir = (self.root / "mix_{}".format(task.split('_')[1])).resolve() self.mix_dir = (self.root / f"mix_{task.split('_')[1]}").resolve()
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)] self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
self.files = [p.name for p in self.mix_dir.glob("*wav")] self.files = [p.name for p in self.mix_dir.glob("*wav")]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment