# utils.py: dataset construction and batch-collation helpers.
from typing import List
from functools import partial
from collections import namedtuple

import torch

7
from . import wsj0mix, librimix
8
9
10
11

# Container returned by the collate functions in this file: `mix`, `src` and
# `mask` each hold the per-sample tensors stacked along a new leading dimension.
Batch = namedtuple("Batch", ["mix", "src", "mask"])


12
def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, librimix_tr_split=None):
    """Build the train, validation and evaluation datasets for ``dataset_type``.

    Args:
        dataset_type: Either ``"wsj0mix"`` or ``"librimix"``.
        root_dir: Dataset root. It is combined with split names through the
            ``/`` operator, so it must support that (e.g. ``pathlib.Path``).
        num_speakers: Forwarded unchanged to the dataset constructor.
        sample_rate: Forwarded unchanged to the dataset constructor.
        task: Forwarded to ``librimix.LibriMix`` only; unused for ``"wsj0mix"``.
        librimix_tr_split: Name of the training sub-directory under
            ``root_dir`` for ``"librimix"``. Required for that dataset and
            unused for ``"wsj0mix"``.

    Returns:
        A ``(train, validation, evaluation)`` tuple of dataset objects.

    Raises:
        ValueError: If ``dataset_type`` is not recognised, or if ``"librimix"``
            is requested without ``librimix_tr_split``.
    """
    if dataset_type == "wsj0mix":
        train = wsj0mix.WSJ0Mix(root_dir / "tr", num_speakers, sample_rate)
        validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate)
        evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate)
    elif dataset_type == "librimix":
        # Fail early with a clear message instead of an opaque error from
        # `root_dir / None` when the training split was not specified.
        if librimix_tr_split is None:
            raise ValueError("librimix_tr_split must be given for the librimix dataset.")
        train = librimix.LibriMix(root_dir / librimix_tr_split, num_speakers, sample_rate, task)
        validation = librimix.LibriMix(root_dir / "dev", num_speakers, sample_rate, task)
        evaluation = librimix.LibriMix(root_dir / "test", num_speakers, sample_rate, task)
    else:
        raise ValueError(f"Unexpected dataset: {dataset_type}")
    return train, validation, evaluation


26
def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, sample_rate: int, random_start=False):
27
    """Ensure waveform has exact number of frames by slicing or padding"""
28
29
    mix = sample[1]  # [1, time]
    src = torch.cat(sample[2], 0)  # [num_sources, time]
30
31

    num_channels, num_frames = src.shape
32
33
    num_seconds = torch.div(num_frames, sample_rate, rounding_mode='floor')
    target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode='floor')
34
35
    if num_frames >= target_num_frames:
        if random_start and num_frames > target_num_frames:
36
            start_frame = torch.randint(num_seconds - target_seconds + 1, [1]) * sample_rate
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
            mix = mix[:, start_frame:]
            src = src[:, start_frame:]
        mix = mix[:, :target_num_frames]
        src = src[:, :target_num_frames]
        mask = torch.ones_like(mix)
    else:
        num_padding = target_num_frames - num_frames
        pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device)
        mix = torch.cat([mix, pad], 1)
        src = torch.cat([src, pad.expand(num_channels, -1)], 1)
        mask = torch.ones_like(mix)
        mask[..., num_frames:] = 0
    return mix, src, mask


def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
    """Collate training samples into a fixed-length ``Batch``.

    Every sample is passed through ``_fix_num_frames`` with
    ``random_start=True`` so that it has ``int(duration * sample_rate)``
    frames, then the per-sample results are stacked along a new dim 0.

    Args:
        samples: Dataset items to collate.
        sample_rate: Frames per second of the waveforms.
        duration: Target length of each item, in seconds.

    Returns:
        ``Batch`` with the stacked ``mix``, ``src`` and ``mask`` tensors.
    """
    target_num_frames = int(duration * sample_rate)

    mixes, srcs, masks = [], [], []
    for sample in samples:
        mix, src, mask = _fix_num_frames(sample, target_num_frames, sample_rate, random_start=True)
        mixes.append(mix)
        srcs.append(src)
        masks.append(mask)

    return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))


66
def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType], sample_rate):
    """Collate evaluation samples into a ``Batch`` sized to the longest item.

    The target length is the largest last-dimension size of ``sample[1]``
    across ``samples``. Each sample is passed through ``_fix_num_frames`` with
    ``random_start=False``, and the results are stacked along a new dim 0.

    Args:
        samples: Dataset items to collate.
        sample_rate: Forwarded to ``_fix_num_frames``.

    Returns:
        ``Batch`` with the stacked ``mix``, ``src`` and ``mask`` tensors.
    """
    max_num_frames = max(s[1].shape[-1] for s in samples)

    mixes, srcs, masks = [], [], []
    for sample in samples:
        mix, src, mask = _fix_num_frames(sample, max_num_frames, sample_rate, random_start=False)
        mixes.append(mix)
        srcs.append(src)
        masks.append(mask)

    return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))


def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
    """Return the collate function for ``dataset_type`` in the given ``mode``.

    Args:
        dataset_type: ``"wsj0mix"`` or ``"librimix"``; both use the same
            collate functions.
        mode: ``"train"`` or ``"test"``.
        sample_rate: Bound into the returned collate function. Required when
            ``mode`` is ``"train"``.
        duration: Length in seconds of each training item; only used when
            ``mode`` is ``"train"``.

    Returns:
        A ``functools.partial`` wrapping ``collate_fn_wsj0mix_train`` for
        ``"train"`` or ``collate_fn_wsj0mix_test`` for ``"test"``.

    Raises:
        ValueError: If ``mode`` or ``dataset_type`` is not recognised, or if
            ``mode`` is ``"train"`` and ``sample_rate`` is ``None``.
    """
    # Raise explicitly instead of using `assert`, which is stripped under
    # `python -O` and would let an invalid mode fall through to the test path.
    if mode not in ("train", "test"):
        raise ValueError(f"Unexpected mode: {mode}")
    if dataset_type not in ("wsj0mix", "librimix"):
        raise ValueError(f"Unexpected dataset: {dataset_type}")

    if mode == "train":
        if sample_rate is None:
            raise ValueError("sample_rate is not given.")
        return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
    return partial(collate_fn_wsj0mix_test, sample_rate=sample_rate)