from collections import namedtuple
from functools import partial
from typing import List

import torch
from torchaudio.datasets import LibriMix

from . import wsj0mix

# Collated mini-batch: mix is [batch, 1, time], src is [batch, num_sources, time],
# and mask is [batch, 1, time] with ones for real frames and zeros for padding.
Batch = namedtuple("Batch", ["mix", "src", "mask"])


def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, librimix_tr_split=None):
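    """Build the train / validation / evaluation datasets for the given dataset type.

    "wsj0mix" expects tr/cv/tt subdirectories under root_dir; "librimix" uses
    torchaudio's LibriMix with librimix_tr_split for training and the dev/test
    splits for validation/evaluation.
    """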
    if dataset_type == "wsj0mix":
        train = wsj0mix.WSJ0Mix(root_dir / "tr", num_speakers, sample_rate)
        validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate)
        evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate)
    elif dataset_type == "librimix":
        train = LibriMix(root_dir, librimix_tr_split, num_speakers, sample_rate, task)
        validation = LibriMix(root_dir, "dev", num_speakers, sample_rate, task)
        evaluation = LibriMix(root_dir, "test", num_speakers, sample_rate, task)
    else:
        raise ValueError(f"Unexpected dataset: {dataset_type}")
    return train, validation, evaluation


def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, sample_rate: int, random_start=False):
    """Ensure waveform has exact number of frames by slicing or padding"""
    mix = sample[1]  # [1, time]
    src = torch.cat(sample[2], 0)  # [num_sources, time]

    num_channels, num_frames = src.shape
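    # Lengths in whole seconds, used to draw a random crop start aligned to
    # second boundaries.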
    num_seconds = torch.div(num_frames, sample_rate, rounding_mode="floor")
    target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode="floor")
    if num_frames >= target_num_frames:
        if random_start and num_frames > target_num_frames:
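            # Choose a random crop start on a whole-second boundary.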
            start_frame = torch.randint(num_seconds - target_seconds + 1, [1]) * sample_rate
            mix = mix[:, start_frame:]
            src = src[:, start_frame:]
        mix = mix[:, :target_num_frames]
        src = src[:, :target_num_frames]
        mask = torch.ones_like(mix)
    else:
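        # Input is shorter than the target: zero-pad mixture and sources, and
        # zero the mask over the padded region.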
        num_padding = target_num_frames - num_frames
        pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device)
        mix = torch.cat([mix, pad], 1)
        src = torch.cat([src, pad.expand(num_channels, -1)], 1)
        mask = torch.ones_like(mix)
        mask[..., num_frames:] = 0
    return mix, src, mask


def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
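    """Collate training samples: randomly crop or pad each one to `duration` seconds
    and stack the results into a Batch."""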
    target_num_frames = int(duration * sample_rate)

    mixes, srcs, masks = [], [], []
    for sample in samples:
        mix, src, mask = _fix_num_frames(sample, target_num_frames, sample_rate, random_start=True)

        mixes.append(mix)
        srcs.append(src)
        masks.append(mask)

    return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))


def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType], sample_rate):
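    """Collate evaluation samples: pad every sample to the longest mixture in the
    batch and stack the results into a Batch."""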
    max_num_frames = max(s[1].shape[-1] for s in samples)

    mixes, srcs, masks = [], [], []
    for sample in samples:
        mix, src, mask = _fix_num_frames(sample, max_num_frames, sample_rate, random_start=False)

        mixes.append(mix)
        srcs.append(src)
        masks.append(mask)

    return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))


def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
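    """Return the collate function for the given dataset type and mode.

    Training batches are cropped/padded to `duration` seconds, which requires
    sample_rate; test batches are padded to the longest sample in the batch.
    """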
    assert mode in ["train", "test"]
    if dataset_type in ["wsj0mix", "librimix"]:
        if mode == "train":
            if sample_rate is None:
                raise ValueError("sample_rate is not given.")
            return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
        return partial(collate_fn_wsj0mix_test, sample_rate=sample_rate)
    raise ValueError(f"Unexpected dataset: {dataset_type}")
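

# Usage sketch (illustrative only): one way these helpers could be wired into a
# torch.utils.data.DataLoader. The root path, batch sizes, and the 2-speaker /
# 8 kHz configuration below are assumptions, not part of this module.
#
#   from pathlib import Path
#
#   train_set, _, eval_set = get_dataset("wsj0mix", Path("/data/wsj0-mix/2speakers/wav8k/min"), 2, 8000)
#   train_loader = torch.utils.data.DataLoader(
#       train_set,
#       batch_size=4,
#       collate_fn=get_collate_fn("wsj0mix", mode="train", sample_rate=8000, duration=4),
#   )
#   eval_loader = torch.utils.data.DataLoader(
#       eval_set,
#       batch_size=1,
#       collate_fn=get_collate_fn("wsj0mix", mode="test", sample_rate=8000),
#   )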