import random

import torch
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
from torch.utils.data.dataset import random_split
from torchaudio.datasets import LIBRITTS, LJSPEECH
from torchaudio.transforms import MuLawEncoding


class MapMemoryCache(torch.utils.data.Dataset):
    r"""Dataset wrapper that memoizes items: each index is fetched from the
    underlying dataset at most once and served from memory afterwards."""

    def __init__(self, dataset):
        self.dataset = dataset
        # One slot per index; None marks "not fetched yet".
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        cached = self._cache[n]
        if cached is None:
            cached = self.dataset[n]
            self._cache[n] = cached
        return cached

    def __len__(self):
        return len(self.dataset)


class Processed(torch.utils.data.Dataset):
    r"""Dataset wrapper that turns each raw item into a
    ``(waveform, spectrogram)`` pair by applying ``transforms`` on the fly."""

    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms

    def __getitem__(self, key):
        return self.process_datapoint(self.dataset[key])

    def __len__(self):
        return len(self.dataset)

    def process_datapoint(self, item):
        # item[0] is the waveform; drop its leading channel dimension and
        # pair it with the transformed (spectrogram) version.
        waveform = item[0]
        specgram = self.transforms(waveform)
        return waveform.squeeze(0), specgram


def split_process_dataset(args, transforms):
    r"""Load the dataset named by ``args.dataset``, split it into train and
    validation parts, and wrap both in on-the-fly processing plus an
    in-memory cache.

    Args:
        args: namespace with ``dataset`` (``"ljspeech"`` or ``"libritts"``),
            ``file_path`` (dataset root) and, for ljspeech, ``val_ratio``.
        transforms: callable mapping a waveform to a spectrogram.

    Returns:
        Tuple of ``(train_dataset, val_dataset)``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    if args.dataset == "ljspeech":
        data = LJSPEECH(root=args.file_path, download=False)

        # Carve a random validation subset off LJSPEECH according to val_ratio.
        n_val = int(len(data) * args.val_ratio)
        train_dataset, val_dataset = random_split(data, [len(data) - n_val, n_val])
    elif args.dataset == "libritts":
        # LibriTTS ships with predefined splits, so no manual split is needed.
        train_dataset = LIBRITTS(root=args.file_path, url="train-clean-100", download=False)
        val_dataset = LIBRITTS(root=args.file_path, url="dev-clean", download=False)
    else:
        raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}")

    train_dataset = MapMemoryCache(Processed(train_dataset, transforms))
    val_dataset = MapMemoryCache(Processed(val_dataset, transforms))

    return train_dataset, val_dataset


def collate_factory(args):
    r"""Return a collate function that crops random, aligned waveform and
    spectrogram windows from each example in a batch.

    Args:
        args: namespace providing ``kernel_size``, ``hop_length``,
            ``seq_len_factor``, ``loss``, ``mulaw`` and ``n_bits``.

    Returns:
        A callable ``raw_collate(batch) -> (waveform, specgram, target)``
        where each output has an extra channel dimension inserted at dim 1.
    """

    def raw_collate(batch):
        pad = (args.kernel_size - 1) // 2

        # Window lengths for the waveform crop and its spectrogram counterpart.
        wave_length = args.hop_length * args.seq_len_factor
        spec_length = args.seq_len_factor + pad * 2

        # Pick a random crop start per example; the extra pad * 2 margin keeps
        # the crop safely inside the spectrogram.
        spec_offsets = [
            random.randint(0, x[1].shape[-1] - (spec_length + pad * 2)) for x in batch
        ]
        # Corresponding start positions in the waveform domain.
        wave_offsets = [(start + pad) * args.hop_length for start in spec_offsets]

        # One extra waveform sample is kept so input and target can be
        # shifted against each other by one step.
        waveform_combine = torch.stack(
            [x[0][start : start + wave_length + 1] for x, start in zip(batch, wave_offsets)]
        )
        specgram = torch.stack(
            [x[1][:, start : start + spec_length] for x, start in zip(batch, spec_offsets)]
        )

        waveform = waveform_combine[:, :wave_length]
        target = waveform_combine[:, 1:]

        # waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'crossentropy'
        if args.loss == "crossentropy":
            if args.mulaw:
                mulaw_encode = MuLawEncoding(2**args.n_bits)
                waveform = mulaw_encode(waveform)
                target = mulaw_encode(target)
                # The network input stays a normalized waveform; only the
                # target keeps the integer mu-law bins.
                waveform = bits_to_normalized_waveform(waveform, args.n_bits)
            else:
                target = normalized_waveform_to_bits(target, args.n_bits)

        return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1)

    return raw_collate