Unverified Commit 2d879132 authored by moto's avatar moto Committed by GitHub
Browse files

Add wsj0-mix dataset to source separation example (#895)

parent ba7b7a2f
from . import ( from . import (
dataset,
metrics, metrics,
) )
from typing import List
from functools import partial
from collections import namedtuple
import torch
from . import wsj0mix
Batch = namedtuple("Batch", ["mix", "src", "mask"])
def get_dataset(dataset_type, root_dir, num_speakers, sample_rate):
if dataset_type == "wsj0mix":
train = wsj0mix.WSJ0Mix(root_dir / "tr", num_speakers, sample_rate)
validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate)
evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate)
else:
raise ValueError(f"Unexpected dataset: {dataset_type}")
return train, validation, evaluation
def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, random_start=False):
"""Ensure waveform has exact number of frames by slicing or padding"""
mix = sample[1] # [1, num_frames]
src = torch.cat(sample[2], 0) # [num_sources, num_frames]
num_channels, num_frames = src.shape
if num_frames >= target_num_frames:
if random_start and num_frames > target_num_frames:
start_frame = torch.randint(num_frames - target_num_frames, [1])
mix = mix[:, start_frame:]
src = src[:, start_frame:]
mix = mix[:, :target_num_frames]
src = src[:, :target_num_frames]
mask = torch.ones_like(mix)
else:
num_padding = target_num_frames - num_frames
pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device)
mix = torch.cat([mix, pad], 1)
src = torch.cat([src, pad.expand(num_channels, -1)], 1)
mask = torch.ones_like(mix)
mask[..., num_frames:] = 0
return mix, src, mask
def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
target_num_frames = int(duration * sample_rate)
mixes, srcs, masks = [], [], []
for sample in samples:
mix, src, mask = _fix_num_frames(sample, target_num_frames, random_start=True)
mixes.append(mix)
srcs.append(src)
masks.append(mask)
return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))
def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType]):
max_num_frames = max(s[1].shape[-1] for s in samples)
mixes, srcs, masks = [], [], []
for sample in samples:
mix, src, mask = _fix_num_frames(sample, max_num_frames, random_start=False)
mixes.append(mix)
srcs.append(src)
masks.append(mask)
return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))
def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
assert mode in ["train", "test"]
if dataset_type == "wsj0mix":
if mode == 'train':
if sample_rate is None:
raise ValueError("sample_rate is not given.")
return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
return collate_fn_wsj0mix_test
raise ValueError(f"Unexpected dataset: {dataset_type}")
from pathlib import Path
from typing import Union, Tuple, List
import torch
from torch.utils.data import Dataset
import torchaudio
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
class WSJ0Mix(Dataset):
"""Create a Dataset for wsj0-mix.
Args:
root (str or Path): Path to the directory where the dataset is found.
num_speakers (int): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios.
sample_rate (int): Expected sample rate of audio files. If any of the audio has a
different sample rate, raises ``ValueError``.
audio_ext (str): The extension of audio files to find. (default: ".wav")
"""
def __init__(
self,
root: Union[str, Path],
num_speakers: int,
sample_rate: int,
audio_ext: str = ".wav",
):
self.root = Path(root)
self.sample_rate = sample_rate
self.mix_dir = (self.root / "mix").resolve()
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
self.files = [p.name for p in self.mix_dir.glob(f"*{audio_ext}")]
self.files.sort()
def _load_audio(self, path) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(path)
if sample_rate != self.sample_rate:
raise ValueError(
f"The dataset contains audio file of sample rate {sample_rate}. "
"Where the requested sample rate is {self.sample_rate}."
)
return waveform
def _load_sample(self, filename) -> SampleType:
mixed = self._load_audio(str(self.mix_dir / filename))
srcs = []
for i, dir_ in enumerate(self.src_dirs):
src = self._load_audio(str(dir_ / filename))
if mixed.shape != src.shape:
raise ValueError(
f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
)
srcs.append(src)
return self.sample_rate, mixed, srcs
def __len__(self) -> int:
return len(self.files)
def __getitem__(self, key: int) -> SampleType:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
"""
return self._load_sample(self.files[key])
import os
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_whitenoise,
save_wav,
normalize_wav,
)
from source_separation.utils.dataset import wsj0mix
_FILENAMES = [
"012c0207_1.9952_01cc0202_-1.9952.wav",
"01co0302_1.63_014c020q_-1.63.wav",
"01do0316_0.24011_205a0104_-0.24011.wav",
"01lc020x_1.1301_027o030r_-1.1301.wav",
"01mc0202_0.34056_205o0106_-0.34056.wav",
"01nc020t_0.53821_018o030w_-0.53821.wav",
"01po030f_2.2136_40ko031a_-2.2136.wav",
"01ra010o_2.4098_403a010f_-2.4098.wav",
"01xo030b_0.22377_016o031a_-0.22377.wav",
"02ac020x_0.68566_01ec020b_-0.68566.wav",
"20co010m_0.82801_019c0212_-0.82801.wav",
"20da010u_1.2483_017c0211_-1.2483.wav",
"20oo010d_1.0631_01ic020s_-1.0631.wav",
"20sc0107_2.0222_20fo010h_-2.0222.wav",
"20tc010f_0.051456_404a0110_-0.051456.wav",
"407c0214_1.1712_02ca0113_-1.1712.wav",
"40ao030w_2.4697_20vc010a_-2.4697.wav",
"40pa0101_1.1087_40ea0107_-1.1087.wav",
]
def _mock_dataset(root_dir, num_speaker):
dirnames = ["mix"] + [f"s{i+1}" for i in range(num_speaker)]
for dirname in dirnames:
os.makedirs(os.path.join(root_dir, dirname), exist_ok=True)
seed = 0
sample_rate = 8000
expected = []
for filename in _FILENAMES:
mix = None
src = []
for dirname in dirnames:
waveform = get_whitenoise(
sample_rate=8000, duration=1, n_channels=1, dtype="int16", seed=seed
)
seed += 1
path = os.path.join(root_dir, dirname, filename)
save_wav(path, waveform, sample_rate)
waveform = normalize_wav(waveform)
if dirname == "mix":
mix = waveform
else:
src.append(waveform)
expected.append((sample_rate, mix, src))
return expected
class TestWSJ0Mix2(TempDirMixin, TorchaudioTestCase):
backend = "default"
root_dir = None
expected = None
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
cls.expected = _mock_dataset(cls.root_dir, 2)
def test_wsj0mix(self):
dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=2, sample_rate=8000)
n_ite = 0
for i, sample in enumerate(dataset):
(_, sample_mix, sample_src) = sample
(_, expected_mix, expected_src) = self.expected[i]
self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
self.assertEqual(sample_src[0], expected_src[0], atol=5e-5, rtol=1e-8)
self.assertEqual(sample_src[1], expected_src[1], atol=5e-5, rtol=1e-8)
n_ite += 1
assert n_ite == len(self.expected)
class TestWSJ0Mix3(TempDirMixin, TorchaudioTestCase):
backend = "default"
root_dir = None
expected = None
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
cls.expected = _mock_dataset(cls.root_dir, 3)
def test_wsj0mix(self):
dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=3, sample_rate=8000)
n_ite = 0
for i, sample in enumerate(dataset):
(_, sample_mix, sample_src) = sample
(_, expected_mix, expected_src) = self.expected[i]
self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
self.assertEqual(sample_src[0], expected_src[0], atol=5e-5, rtol=1e-8)
self.assertEqual(sample_src[1], expected_src[1], atol=5e-5, rtol=1e-8)
self.assertEqual(sample_src[2], expected_src[2], atol=5e-5, rtol=1e-8)
n_ite += 1
assert n_ite == len(self.expected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment