Commit dc0990c7 authored by nateanl's avatar nateanl Committed by Zhaoheng Ni
Browse files

[Cherry-picked 0.10] Move LibriMix dataset to datasets directory (#1833)

parent e6fccfda
from argparse import ArgumentParser
import pathlib
from pathlib import Path
from lightning_train import _get_model, _get_dataloader, sisdri_metric
import mir_eval
import torch
def eval(model, data_loader, device):
def _eval(model, data_loader, device):
results = torch.zeros(4)
with torch.no_grad():
for i, batch in enumerate(data_loader):
for _, batch in enumerate(data_loader):
mix, src, mask = batch
mix, src, mask = mix.to(device), src.to(device), mask.to(device)
est = model(mix)
......@@ -35,7 +35,11 @@ def eval(model, data_loader, device):
def cli_main():
parser = ArgumentParser()
parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"])
parser.add_argument("--data-dir", default=pathlib.Path("./Libri2Mix/wav8k/min"), type=pathlib.Path)
parser.add_argument(
"--root-dir",
type=Path,
help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
)
parser.add_argument(
"--librimix-tr-split",
default="train-360",
......@@ -60,8 +64,8 @@ def cli_main():
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
default=Path("./exp"),
type=Path,
help="The directory to save checkpoints and logs."
)
parser.add_argument(
......@@ -95,7 +99,7 @@ def cli_main():
args.librimix_tr_split,
)
eval(model, eval_loader, device)
_eval(model, eval_loader, device)
if __name__ == "__main__":
......
#!/usr/bin/env python3
# pyre-strict
import pathlib
from pathlib import Path
from argparse import ArgumentParser
from typing import (
Any,
......@@ -13,6 +12,7 @@ from typing import (
Optional,
Tuple,
TypedDict,
Union,
)
import torch
......@@ -279,7 +279,7 @@ def _get_model(
def _get_dataloader(
dataset_type: str,
dataset_dir: pathlib.Path,
root_dir: Union[str, Path],
num_speakers: int = 2,
sample_rate: int = 8000,
batch_size: int = 6,
......@@ -291,11 +291,11 @@ def _get_dataloader(
Args:
dataset_type (str): the dataset to use.
dataset_dir (pathlib.Path): the root directory of the dataset.
num_speakers (int): the number of speakers in the mixture. (Default: 2)
sample_rate (int): the sample rate of the audio. (Default: 8000)
batch_size (int): the batch size of the dataset. (Default: 6)
num_workers (int): the number of workers for each dataloader. (Default: 4)
root_dir (str or Path): the root directory of the dataset.
num_speakers (int, optional): the number of speakers in the mixture. (Default: 2)
sample_rate (int, optional): the sample rate of the audio. (Default: 8000)
batch_size (int, optional): the batch size of the dataset. (Default: 6)
num_workers (int, optional): the number of workers for each dataloader. (Default: 4)
librimix_task (str or None, optional): the task in LibriMix dataset.
librimix_tr_split (str or None, optional): the training split in LibriMix dataset.
......@@ -303,7 +303,7 @@ def _get_dataloader(
tuple: (train_loader, valid_loader, eval_loader)
"""
train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
dataset_type, dataset_dir, num_speakers, sample_rate, librimix_task, librimix_tr_split
dataset_type, root_dir, num_speakers, sample_rate, librimix_task, librimix_tr_split
)
train_collate_fn = dataset_utils.get_collate_fn(
dataset_type, mode='train', sample_rate=sample_rate, duration=3
......@@ -337,9 +337,13 @@ def _get_dataloader(
def cli_main():
parser = ArgumentParser()
parser.add_argument("--batch-size", default=3, type=int)
parser.add_argument("--batch-size", default=6, type=int)
parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"])
parser.add_argument("--data-dir", default=pathlib.Path("./Libri2Mix/wav8k/min"), type=pathlib.Path)
parser.add_argument(
"--root-dir",
type=Path,
help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
)
parser.add_argument(
"--librimix-tr-split",
default="train-360",
......@@ -364,8 +368,8 @@ def cli_main():
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
default=Path("./exp"),
type=Path,
help="The directory to save checkpoints and logs."
)
parser.add_argument(
......@@ -404,7 +408,7 @@ def cli_main():
)
train_loader, valid_loader, eval_loader = _get_dataloader(
args.dataset,
args.data_dir,
args.root_dir,
args.num_speakers,
args.sample_rate,
args.batch_size,
......
from . import utils, wsj0mix, librimix
from . import utils, wsj0mix
__all__ = ['utils', 'wsj0mix', 'librimix']
__all__ = ['utils', 'wsj0mix']
......@@ -2,9 +2,10 @@ from typing import List
from functools import partial
from collections import namedtuple
from torchaudio.datasets import LibriMix
import torch
from . import wsj0mix, librimix
from . import wsj0mix
Batch = namedtuple("Batch", ["mix", "src", "mask"])
......@@ -15,9 +16,9 @@ def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, li
validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate)
evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate)
elif dataset_type == "librimix":
train = librimix.LibriMix(root_dir / librimix_tr_split, num_speakers, sample_rate, task)
validation = librimix.LibriMix(root_dir / "dev", num_speakers, sample_rate, task)
evaluation = librimix.LibriMix(root_dir / "test", num_speakers, sample_rate, task)
train = LibriMix(root_dir, librimix_tr_split, num_speakers, sample_rate, task)
validation = LibriMix(root_dir, "dev", num_speakers, sample_rate, task)
evaluation = LibriMix(root_dir, "test", num_speakers, sample_rate, task)
else:
raise ValueError(f"Unexpected dataset: {dataset_type}")
return train, validation, evaluation
......
......@@ -8,6 +8,7 @@ from .yesno import YESNO
from .ljspeech import LJSPEECH
from .cmuarctic import CMUARCTIC
from .cmudict import CMUDict
from .librimix import LibriMix
from .libritts import LIBRITTS
from .tedlium import TEDLIUM
......@@ -23,6 +24,7 @@ __all__ = [
"GTZAN",
"CMUARCTIC",
"CMUDict",
"LibriMix",
"LIBRITTS",
"diskcache_iterator",
"bg_iterator",
......
......@@ -13,27 +13,43 @@ class LibriMix(Dataset):
r"""Create the LibriMix dataset.
Args:
root (str or Path): the path to the directory where the dataset is stored.
root (str or Path): The path to the directory where the directory ``Libri2Mix`` or
``Libri3Mix`` is stored.
subset (str, optional): The subset to use. Options: [``train-360`, ``train-100``,
``dev``, and ``test``] (Default: ``train-360``).
num_speakers (int, optional): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios. (Default: 2)
sample_rate (int, optional): sample rate of audio files. If any of the audio has a
different sample rate, raises ``ValueError``. (Default: 8000)
sample_rate (int, optional): sample rate of audio files. The ``sample_rate`` determines
which subdirectory the audio are fetched. If any of the audio has a different sample
rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
task (str, optional): the task of LibriMix.
Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``]
(Default: ``sep_clean``)
Note:
The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
"""
def __init__(
self,
root: Union[str, Path],
subset: str = "train-360",
num_speakers: int = 2,
sample_rate: int = 8000,
task: str = "sep_clean",
):
self.root = Path(root)
self.root = Path(root) / f"Libri{num_speakers}Mix"
if sample_rate == 8000:
self.root = self.root / "wav8k/min" / subset
elif sample_rate == 16000:
self.root = self.root / "wav16k/min" / subset
else:
raise ValueError(
f"Unsupported sample rate. Found {sample_rate}."
)
self.sample_rate = sample_rate
self.task = task
self.mix_dir = (self.root / "mix_{}".format(task.split('_')[1])).resolve()
self.mix_dir = (self.root / f"mix_{task.split('_')[1]}").resolve()
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
self.files = [p.name for p in self.mix_dir.glob("*wav")]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment