# hubert_dataset.py
1
import math
2
from pathlib import Path
3
from typing import Dict, Iterator, List, Optional, Tuple, Union
4

5
6
import numpy as np
import torch
7
import torch.distributed as dist
8
import torchaudio
9
from torch import Tensor
10
from torch.utils.data import BatchSampler, Dataset, DistributedSampler
11
12


class BucketizeBatchSampler(BatchSampler):
    """Bucketized BatchSampler for sequential data with different lengths to reduce number of paddings.

    Args:
        lengths (List[int]): The lengths of the samples in the dataset.
        num_buckets (int): The number of buckets to split the data samples.
        min_len (int, optional): The minimum sample lengths to keep.
            (Default: 0)
        max_len (int or None, optional): The maximum sample lengths to keep. Inferred if not provided.
            (Default ``None``)
        max_token_count (int or None, optional): The max number of tokens in one mini-batch.
            (Default: ``None``)
        batch_size (int or None, optional): The number of samples in one mini-batch.
            (Default: ``None``)
        shuffle (bool, optional): Whether to shuffle buckets for non-monotonic length sampling.
            (Default: True)
        drop_last (bool, optional): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
            (Default: False)

    Note:
        ``max_token_count`` and ``batch_size`` are mutually exclusive. Only one argument of the two
        should have value.

    Note:
        ``drop_last`` is only valid when ``batch_size`` argument is given.

    Note:
        if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
        in pytorch_lightning Trainer to enable shuffling every epoch.

    Note:
        Shuffling uses the global torch RNG (``torch.randperm`` without a generator). When this sampler is
        wrapped by a per-rank distributed sampler, presumably every rank must seed torch identically so that
        all ranks build the same batch list — NOTE(review): confirm the training script does this.
    """

    def __init__(
        self,
        lengths: List[int],
        num_buckets: int,
        min_len: int = 0,
        max_len: Optional[int] = None,
        max_token_count: Optional[int] = None,
        batch_size: Optional[int] = None,
        shuffle: bool = True,
        drop_last: bool = False,
    ) -> None:
        if max_len is None:
            max_len = max(lengths)

        # Explicit raises (not ``assert``) so validation also runs under ``python -O``.
        if not (0 <= min_len <= max_len):
            raise AssertionError("``min_len`` should be non-negative and smaller than ``max_len``")
        if max_token_count is not None and batch_size is not None:
            raise AssertionError("The ``max_token_count`` and ``batch_size`` can't be both set.")
        if max_token_count is None and batch_size is None:
            raise AssertionError("One of ``max_token_count`` or ``batch_size`` must be set.")
        if max_token_count is not None and max_len > max_token_count:
            raise AssertionError(
                "The  ``max_token_count`` must be greater than or equal to the maximum value of ``lengths``."
            )

        # Filter out samples which are outside the bounds of [min_len, max_len]
        filtered_length_idx = [(length, i) for i, length in enumerate(lengths) if min_len <= length <= max_len]
        if len(filtered_length_idx) == 0:
            raise AssertionError("``lengths`` cannot be empty after filtering.")
        sorted_filtered_length_idx = sorted(filtered_length_idx, key=lambda x: x[0])
        # ``self.lengths[j]`` is the length of original sample ``self.indices[j]``, sorted ascending.
        self.lengths = [e[0] for e in sorted_filtered_length_idx]
        self.indices = [e[1] for e in sorted_filtered_length_idx]

        self.max_token_count = max_token_count
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.buckets = self._get_buckets(self.lengths, num_buckets, min_len, max_len)
        self._update_iter_list()

    def _get_buckets(self, lengths: List[int], num_buckets: int, min_len: int, max_len: int) -> Dict[int, Tensor]:
        """Generate buckets based on the dataset.

        Args:
            lengths (List[int]): The lengths of the samples in the dataset.
            num_buckets (int): The number of buckets.
            min_len (int): The lower bound of the evenly spaced length intervals to determine bucket width.
            max_len (int): The upper bound of the evenly spaced length intervals to determine bucket width.

        Returns:
            (dict[int, Tensor]): A dictionary in which the key is the bucket index, the value is
                the Tensor of corresponding sample indices (positions into ``lengths``).
        """
        buckets: Dict[int, List[int]] = {}
        # Widen the range by one on each side so min_len/max_len fall strictly inside it.
        boundaries = torch.linspace(min_len - 1, max_len + 1, num_buckets + 1)
        bucket_ids = torch.bucketize(torch.tensor(lengths), boundaries)
        for i, bucket_id in enumerate(bucket_ids.tolist()):
            buckets.setdefault(int(bucket_id), []).append(i)
        return {k: torch.as_tensor(buckets[k], dtype=torch.int) for k in sorted(buckets)}

    def _update_iter_list(self) -> None:
        """(Optionally) shuffle each bucket, then greedily pack its samples into mini-batches.

        Buckets are visited in ascending key order. In ``max_token_count`` mode a batch is closed when
        adding the next sample would exceed the token budget; in ``batch_size`` mode each sample counts 1.
        """
        if self.shuffle:
            for k in self.buckets:
                self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0))]
        use_tokens = self.max_token_count is not None
        max_batch_size = self.max_token_count if use_tokens else self.batch_size

        self.iter_list = []
        total_len = 0
        batch = []
        for bucket in self.buckets.values():
            for index in bucket.tolist():
                sample_length = self.lengths[index] if use_tokens else 1
                if total_len + sample_length <= max_batch_size:
                    batch.append(self.indices[index])
                    total_len += sample_length
                else:
                    self.iter_list.append(batch)
                    batch = [self.indices[index]]
                    total_len = sample_length

        if batch:
            # ``drop_last`` only applies in ``batch_size`` mode, and only to an *incomplete* final batch;
            # a full trailing batch must be kept.
            incomplete = not use_tokens and len(batch) < self.batch_size
            if not (self.drop_last and incomplete):
                self.iter_list.append(batch)

    def __iter__(self) -> Iterator[List[int]]:
        return iter(self.iter_list)

    def __len__(self) -> int:
        # ``iter_list`` is built once in ``__init__``, so its length is exact in every mode
        # (including ``max_token_count`` with ``shuffle=True``).
        return len(self.iter_list)


class DistributedBatchSampler(DistributedSampler):
    """`BucketizeBatchSampler` wrapper that distributes across each processor.

    Args:
        batch_sampler (BucketizeBatchSampler): the initialized bucketize batch sampler.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): if ``True``, the list of batch indices will be shuffled.
            (Default: ``True``)
        seed (int, optional): random seed used to shuffle the batch_sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. (Default: ``0``)
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. (Default: ``False``)

    Raises:
        RuntimeError: if ``num_replicas``/``rank`` must be inferred but ``torch.distributed`` is unavailable.
        ValueError: if ``rank`` is not in ``[0, num_replicas - 1]``.

    Note:
        if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
        in pytorch_lightning Trainer, and set `sampler.set_epoch(self.current_epoch)` before DataLoader initialization
        in `train_dataloader` method to enable shuffling every epoch.
    """

    def __init__(
        self,
        batch_sampler: "BucketizeBatchSampler",
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        # ``DistributedSampler.__init__`` is intentionally not called: it requires a dataset,
        # whereas this class distributes the batch list of ``batch_sampler``.
        self.batch_sampler = batch_sampler
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.drop_last = drop_last

        if shuffle:
            # Seeded generator so every rank produces the same permutation.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            perm = torch.randperm(len(self.batch_sampler.iter_list), generator=g).tolist()
            indices = [self.batch_sampler.iter_list[i] for i in perm]
        else:
            # Copy so the padding below never mutates the wrapped sampler's ``iter_list``.
            indices = list(self.batch_sampler.iter_list)

        if self.drop_last:
            self.total_size = len(indices) - len(indices) % self.num_replicas
        else:
            # Smallest padding making the batch count divisible by ``num_replicas`` (0 if already divisible).
            padding_size = -len(indices) % self.num_replicas
            if padding_size > 0:
                # Repeat the list when there are fewer batches than the padding needed.
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
            self.total_size = len(indices)
        self.num_samples = self.total_size // self.num_replicas
        self.subset = indices[self.rank : self.total_size : self.num_replicas]
        assert len(self.subset) == self.num_samples

    def __iter__(self):
        return iter(self.subset)

    def __len__(self):
        return self.num_samples


class HuBERTDataSet(Dataset):
    """Create a Dataset for HuBERT model training and fine-tuning.

    Reads ``<exp_dir>/tsv/<dataset>_<subset>.tsv`` (first line is the audio root directory, every
    following line is ``<relative path>\\t<number of samples>``) and the text label file
    ``<exp_dir>/label/label_<subset>.pt`` (one line of whitespace-separated integer labels per utterance).

    Args:
        exp_dir (str or Path): The root directory of the ``.tsv`` file list.
        dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
        subset (str): The subset of the dataset. Options: [``train``, ``valid``].
    """

    def __init__(
        self,
        exp_dir: Union[str, Path],
        dataset: str,
        subset: str,
    ) -> None:
        self.exp_dir = Path(exp_dir)
        tsv_dir = self.exp_dir / "tsv"
        label_dir = self.exp_dir / "label"
        f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset)
        self.f_list, self.ind_list, self.len_list = f_list, ind_list, len_list
        # Must run after ``self.ind_list`` is set: labels are selected by those indices.
        self.labels = self._load_labels(label_dir, dataset, subset)

    def __len__(self):
        return len(self.f_list)

    def _get_lists(
        self,
        tsv_dir: Path,
        dataset: str,
        subset: str,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Get the list of paths for iteration.

        Args:
            tsv_dir (Path): The root directory of the ``.tsv`` file list.
            dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
            subset (str): The subset of the dataset. Options: [``train``, ``valid``].

        Returns:
            (numpy.ndarray) List of file paths (``<root>/<relative path>``).
            (numpy.ndarray) List of indices (0-based line numbers after the root line).
            (numpy.ndarray) List of waveform lengths.
        """
        f_list, ind_list, len_list = [], [], []
        with open(tsv_dir / f"{dataset}_{subset}.tsv") as f:
            root = f.readline().rstrip()
            for index, line in enumerate(f):
                path, nsample = line.split("\t")
                f_list.append(f"{root}/{path}")
                ind_list.append(index)
                len_list.append(int(nsample))  # int() tolerates the trailing newline
        return np.asarray(f_list), np.asarray(ind_list), np.asarray(len_list)

    def _load_audio(self, index: int) -> Tensor:
        """Load waveform given the sample index of the dataset.

        Args:
            index (int): The sample index.

        Returns:
            (Tensor): The corresponding waveform Tensor.

        Raises:
            AssertionError: if the loaded waveform length differs from the length in the ``.tsv`` file.
        """
        wav_path = self.f_list[index]
        waveform, _ = torchaudio.load(wav_path)
        if waveform.shape[1] != self.len_list[index]:
            raise AssertionError(
                f"Waveform length {waveform.shape[1]} of {wav_path} does not match "
                f"the expected length {self.len_list[index]}."
            )
        return waveform

    def _load_labels(self, label_dir: Path, dataset: str, subset: str) -> np.ndarray:
        """Load all labels to memory into a numpy array.

        Args:
            label_dir (Path): The directory that contains the label file.
            dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
                Not used in the label file name (only ``subset`` is).
            subset (str): The subset of the dataset. Options: [``train``, ``valid``].

        Returns:
            (np.ndarray): A fixed-width bytes array holding one label line per audio file, ordered
                like ``self.f_list``.
        """
        # Despite the ``.pt`` suffix, the file is read as plain text, one line per utterance.
        with open(label_dir / f"label_{subset}.pt") as f:
            labels = [line.rstrip() for line in f]
        labels = [labels[i] for i in self.ind_list]
        # ``np.bytes_`` rather than ``np.string_``: the latter alias was removed in NumPy 2.0.
        return np.asarray(labels, dtype=np.bytes_)

    def __getitem__(self, index):
        waveform = self._load_audio(index)
        length = waveform.shape[1]
        # ``int`` accepts the bytes tokens produced by splitting a ``np.bytes_`` label line.
        label = [int(ele) for ele in self.labels[index].split()]
        label = torch.tensor(label)
        return (waveform, label, length)


307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
def _crop_audio_label(
    waveform: Tensor,
    label: Tensor,
    length: Tensor,
    num_frames: int,
    rand_crop: bool,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Collate the audio and label at the same time.
    Args:
        waveform (Tensor): The waveform Tensor with dimensions `(1, time)`.
        label (Tensor): The label Tensor with dimensions `(1, seq)`.
        length (Tensor): The length Tensor with dimension `(1,)`.
        num_frames (int): The final length of the waveform.
        rand_crop (bool): if ``rand_crop`` is True, the starting index of the
            waveform and label is random if the length is longer than the minimum
            length in the mini-batch.

    Returns:
        (Tuple(Tensor, Tensor, Tensor)): Returns the Tensors for the waveform,
            label, and the waveform length.
    """
    kernel_size = 25
    stride = 20
    sample_rate = 16  # 16 per millisecond
    frame_offset = 0
    waveform = waveform[0]
    if waveform.size(0) > num_frames and rand_crop:
        diff = waveform.size(0) - num_frames
        frame_offset = torch.randint(diff, size=(1,))
    elif waveform.size(0) < num_frames:
        num_frames = waveform.size(0)
    label_offset = max(math.floor((frame_offset - kernel_size * sample_rate) / (stride * sample_rate)) + 1, 0)
    num_label = math.floor((num_frames - kernel_size * sample_rate) / (stride * sample_rate)) + 1
    waveform = waveform[frame_offset : frame_offset + num_frames]
    label = label[label_offset : label_offset + num_label]
    length = num_frames

    return waveform, label, length


class CollateFnHubert:
    """Collate callable that batches HuBERT training/fine-tuning samples.

    Args:
        feature_type (str): The type of features the KMeans labels were computed from.
            Options: [``mfcc``, ``hubert``].
        pad (bool, optional): If ``True``, every sample is kept at full length and the batch is
            zero-padded to the longest waveform. If ``False``, every sample is cropped to the
            shortest waveform in the batch. (Default: False)
        rand_crop (bool, optional): If ``True``, samples longer than the crop length start at a
            random offset; otherwise cropping starts at index 0. (Default: True)
    """

    def __init__(
        self,
        feature_type: str,
        pad: bool = False,
        rand_crop: bool = True,
    ) -> None:
        self.feature_type = feature_type
        self.pad = pad
        self.rand_crop = rand_crop

    def __call__(self, batch: List[Tuple[Tensor, Tensor, int]]) -> Tuple[Tensor, Tensor, Tensor]:
        """Crop or pad a list of samples into batched tensors.

        Args:
            batch (List[Tuple(Tensor, Tensor, int)]):
                Per-sample ``(waveform, label, length)`` tuples; waveforms are `(1, time)`.

        Returns:
            (Tuple(Tensor, Tensor, Tensor)):
                The Tensor of waveforms with dimensions `(batch, time)`.
                The Tensor of labels with dimensions `(batch, seq)`.
                The Tensor of audio lengths with dimension `(batch,)`.
        """
        reduce_len = max if self.pad else min
        num_frames = reduce_len([sample[0].shape[1] for sample in batch])
        halve_labels = self.feature_type == "mfcc"

        cropped = []
        for sample in batch:
            wav, lab, wav_len = sample
            if halve_labels:
                # MFCC labels are 10 ms per frame but the HuBERT transformer emits 20 ms frames,
                # so keep every other label.
                lab = lab[::2]
            cropped.append(_crop_audio_label(wav, lab, wav_len, num_frames, self.rand_crop))

        waveforms = [item[0] for item in cropped]
        labels = [item[1] for item in cropped]
        lengths = [item[2] for item in cropped]

        # Without padding, cropping must already have produced equal-length tensors.
        if not self.pad:
            first_wav_len = waveforms[0].shape[0]
            assert all(
                w.shape[0] == first_wav_len for w in waveforms
            ), "The dimensions of the waveforms should be identical in the same batch."
            first_lab_len = labels[0].shape[0]
            assert all(
                l.shape[0] == first_lab_len for l in labels
            ), "The dimensions of the labels should be identical in the same batch."

        padded_waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
        padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
        return padded_waveforms, padded_labels, torch.tensor(lengths)