Commit 720c36b1 authored by Caroline Chen's avatar Caroline Chen
Browse files

Add metadata for Librimix (#2751)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2751

Reviewed By: nateanl

Differential Revision: D40267874

Pulled By: carolineechen

fbshipit-source-id: 4e45a02c650ed65c05cde82289a400a3be877927
parent 9574b7ca
import os
from pathlib import Path
from typing import List, Tuple, Union
import torch
import torchaudio
from torch.utils.data import Dataset
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
from torchaudio.datasets.utils import _load_waveform
_TASKS_TO_MIXTURE = {
"sep_clean": "mix_clean",
......@@ -55,45 +54,62 @@ class LibriMix(Dataset):
if mode not in ["max", "min"]:
raise ValueError(f'Expect ``mode`` to be one in ["min", "max"]. Found {mode}.')
if sample_rate == 8000:
self.root = self.root / "wav8k" / mode / subset
mix_dir = self.root / "wav8k" / mode / subset
elif sample_rate == 16000:
self.root = self.root / "wav16k" / mode / subset
mix_dir = self.root / "wav16k" / mode / subset
else:
raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")
self.sample_rate = sample_rate
self.task = task
self.mix_dir = (self.root / _TASKS_TO_MIXTURE[task]).resolve()
self.mix_dir = mix_dir / _TASKS_TO_MIXTURE[task]
if task == "enh_both":
self.src_dirs = [(self.root / "mix_clean")]
self.src_dirs = [(mix_dir / "mix_clean")]
else:
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
self.src_dirs = [(mix_dir / f"s{i+1}") for i in range(num_speakers)]
self.files = [p.name for p in self.mix_dir.glob("*.wav")]
self.files.sort()
def _load_audio(self, path) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(path)
if sample_rate != self.sample_rate:
raise ValueError(
f"The dataset contains audio file of sample rate {sample_rate}, "
f"but the requested sample rate is {self.sample_rate}."
)
return waveform
def _load_sample(self, filename) -> SampleType:
mixed = self._load_audio(str(self.mix_dir / filename))
def _load_sample(self, key) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
metadata = self.get_metadata(key)
mixed = _load_waveform(self.root, metadata[1], metadata[0])
srcs = []
for i, dir_ in enumerate(self.src_dirs):
src = self._load_audio(str(dir_ / filename))
for i, path_ in enumerate(metadata[2]):
src = _load_waveform(self.root, path_, metadata[0])
if mixed.shape != src.shape:
raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
srcs.append(src)
return self.sample_rate, mixed, srcs
def get_metadata(self, key: int) -> Tuple[int, str, List[str]]:
"""Get metadata for the n-th sample from the dataset.
Args:
key (int): The index of the sample to be loaded
Returns:
Tuple of the following items;
int:
Sample rate
str:
Path to mixed audio
List of str:
List of paths to source audios
"""
filename = self.files[key]
mixed_path = os.path.relpath(self.mix_dir / filename, self.root)
srcs_paths = []
for dir_ in self.src_dirs:
src = os.path.relpath(dir_ / filename, self.root)
srcs_paths.append(src)
return self.sample_rate, mixed_path, srcs_paths
def __len__(self) -> int:
return len(self.files)
def __getitem__(self, key: int) -> SampleType:
def __getitem__(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
"""Load the n-th sample from the dataset.
Args:
......@@ -106,7 +122,7 @@ class LibriMix(Dataset):
Sample rate
Tensor:
Mixture waveform
list of Tensors:
List of Tensors:
List of source waveforms
"""
return self._load_sample(self.files[key])
return self._load_sample(key)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment