# hubert_dataset.py
1
import math
2
3
4
import os

import sys
5
from pathlib import Path
6
from typing import Dict, Iterator, List, Optional, Tuple, Union
7

8
9
import numpy as np
import torch
10
import torch.distributed as dist
11
import torchaudio
12
from torch import Tensor
13
from torch.utils.data import BatchSampler, Dataset, DistributedSampler
14

15
16
17
sys.path.append("..")
from utils import _get_label2id

18

19
20
class BucketizeBatchSampler(BatchSampler):
    """Bucketized BatchSampler for sequential data with different lengths to reduce the number of paddings.

    Args:
        lengths (List[int]): The lengths of the samples in the dataset.
        num_buckets (int): The number of buckets to split the data samples.
        min_len (int, optional): The minimum sample lengths to keep.
            (Default: 0)
        max_len (int or None, optional): The maximum sample lengths to keep. Inferred if not provided.
            (Default ``None``)
        max_token_count (int or None, optional): The max number of tokens in one mini-batch.
            (Default: ``None``)
        batch_size (int or None, optional): The number of samples in one mini-batch.
            (Default: ``None``)
        shuffle (bool, optional): Whether to shuffle buckets for non-monotonic length sampling.
            (Default: True)
        seed (int, optional): The seed for initializing RNG. Only used when `shuffle` is True. (Default: 0)
        drop_last (bool, optional): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
            (Default: False)

    Raises:
        AssertionError: if ``min_len``/``max_len`` are inconsistent, if both or neither of
            ``max_token_count`` and ``batch_size`` are given, if ``max_token_count`` is smaller
            than ``max_len``, or if no sample survives the length filter.

    Note:
        ``max_token_count`` and ``batch_size`` are mutually exclusive. Only one argument of the two
        should have value.

    Note:
        ``drop_last`` is only valid when ``batch_size`` argument is given.

    Note:
        if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
        in pytorch_lightning Trainer and set ``seed`` to ``self.trainer.current_epoch`` to enable shuffling every epoch.
    """

    def __init__(
        self,
        lengths: List[int],
        num_buckets: int,
        min_len: int = 0,
        max_len: Optional[int] = None,
        max_token_count: Optional[int] = None,
        batch_size: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        if max_len is None:
            max_len = max(lengths)

        if not (0 <= min_len <= max_len):
            raise AssertionError("``min_len`` should be non-negative and smaller than ``max_len``")
        if max_token_count is not None and batch_size is not None:
            raise AssertionError("The ``max_token_count`` and ``batch_size`` can't be both set.")
        if max_token_count is None and batch_size is None:
            raise AssertionError("One of ``max_token_count`` or ``batch_size`` must be set.")
        # Explicit raise (instead of ``assert``) so the check survives ``python -O``;
        # a single sample longer than the token budget could never be placed in a batch.
        if max_token_count is not None and max_len > max_token_count:
            raise AssertionError(
                "The  ``max_token_count`` must be greater than or equal to the maximum value of ``lengths``."
            )

        # Filter out samples which are outside the bounds of [min_len, max_len]
        filtered_length_idx = [(length, i) for i, length in enumerate(lengths) if min_len <= length <= max_len]
        if len(filtered_length_idx) == 0:
            raise AssertionError("``lengths`` cannot be empty after filtering.")
        sorted_filtered_length_idx = sorted(filtered_length_idx, key=lambda x: x[0])
        # Parallel lists ordered by length: sorted lengths and their original dataset indices.
        self.lengths = [e[0] for e in sorted_filtered_length_idx]
        self.indices = [e[1] for e in sorted_filtered_length_idx]

        self.max_token_count = max_token_count
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        if self.shuffle:
            self.g = torch.Generator()
            self.g.manual_seed(self.seed)
        self.drop_last = drop_last
        self.buckets = self._get_buckets(self.lengths, num_buckets, min_len, max_len)
        self._update_iter_list()

    def _get_buckets(self, lengths: List[int], num_buckets: int, min_len: int, max_len: int) -> Dict[int, Tensor]:
        """Generate buckets based on the dataset.

        Args:
            lengths (List[int]): The lengths of the samples in the dataset.
            num_buckets (int): The number of buckets.
            min_len (int): The lower bound of the evenly spaced length intervals to determine bucket width.
            max_len (int): The upper bound of the evenly spaced length intervals to determine bucket width.

        Returns:
            (dict[int, Tensor]): A dictionary, ordered by bucket index, in which the key is the bucket
                index and the value is the Tensor of positions into ``lengths`` falling in that bucket.
        """
        boundaries = torch.linspace(min_len - 1, max_len + 1, num_buckets + 1)
        bucket_ids = torch.bucketize(torch.tensor(lengths), boundaries).tolist()
        grouped: Dict[int, List[int]] = {}
        for position, bucket_id in enumerate(bucket_ids):
            grouped.setdefault(bucket_id, []).append(position)
        return {k: torch.as_tensor(grouped[k], dtype=torch.int) for k in sorted(grouped)}

    def _update_iter_list(self) -> None:
        """Build ``self.iter_list``: the mini-batches (lists of dataset indices) this sampler yields.

        Batches are filled greedily bucket by bucket. In token mode a batch holds at most
        ``max_token_count`` summed lengths; otherwise it holds at most ``batch_size`` samples.
        """
        if self.shuffle:
            for k in self.buckets:
                self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0), generator=self.g)]
        self.iter_list = []
        by_tokens = self.max_token_count is not None
        # In sample-count mode every sample "costs" 1, so the budget is simply batch_size.
        max_batch_size = self.max_token_count if by_tokens else self.batch_size
        total_len = 0
        batch = []
        for k in self.buckets:
            for index in self.buckets[k].tolist():
                sample_length = self.lengths[index] if by_tokens else 1
                if total_len + sample_length <= max_batch_size:
                    batch.append(self.indices[index])
                    total_len += sample_length
                else:
                    self.iter_list.append(batch)
                    batch = [self.indices[index]]
                    total_len = sample_length
        # Only an *incomplete* trailing batch is dropped by ``drop_last``; a full one is always kept.
        if batch and (by_tokens or not self.drop_last or len(batch) == self.batch_size):
            self.iter_list.append(batch)

    def __iter__(self) -> Iterator[List[int]]:
        return iter(self.iter_list)

    def __len__(self) -> int:
        # ``iter_list`` is built once in ``__init__``, so its length is exact in every mode
        # (including token-count mode with shuffling).
        return len(self.iter_list)
149
150


151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
class DistributedBatchSampler(DistributedSampler):
    """`BucketizeBatchSampler` wrapper that distributes across each processor.

    Args:
        batch_sampler (BucketizeBatchSampler): the initialized bucketize batch sampler.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): if ``True``, the list of batch indices will be shuffled.
            (Default: ``True``)
        seed (int, optional): random seed used to shuffle the batch_sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. (Default: ``0``)
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. (Default: ``False``)

    Raises:
        RuntimeError: if ``num_replicas`` or ``rank`` is omitted and ``torch.distributed``
            is not available.

    Note:
        if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
        in pytorch_lightning Trainer, and set `sampler.set_epoch(self.current_epoch)` before DataLoader initialization
        in `train_dataloader` method to enable shuffling every epoch.
    """

    def __init__(
        self,
        batch_sampler: "BucketizeBatchSampler",
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        self.batch_sampler = batch_sampler
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.drop_last = drop_last
        self.num_samples, self.total_size = self._compute_sizes(len(self.batch_sampler.iter_list))

    def _compute_sizes(self, num_batches: int) -> Tuple[int, int]:
        """Return ``(num_samples, total_size)``: batches per rank and batches across all ranks."""
        if self.drop_last:
            # Drop the tail so every rank receives the same number of batches.
            num_samples = num_batches // self.num_replicas
        else:
            # Round up; the tail is padded by repeating batches from the start.
            num_samples = math.ceil(num_batches / self.num_replicas)
        return num_samples, num_samples * self.num_replicas

    def __iter__(self) -> Iterator[List[int]]:
        batches = self.batch_sampler.iter_list
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            perm = torch.randperm(len(batches), generator=g).tolist()
            indices = [batches[i] for i in perm]
        else:
            # Copy so that padding below never mutates the wrapped sampler's ``iter_list``.
            indices = list(batches)
        self.num_samples, self.total_size = self._compute_sizes(len(indices))
        if self.drop_last:
            indices = indices[: self.total_size]
        else:
            padding_size = self.total_size - len(indices)
            if padding_size > 0:
                # Repeat the list as often as needed, which also covers fewer batches than replicas.
                repeats = math.ceil(padding_size / len(indices))
                indices += (indices * repeats)[:padding_size]
        self.subset = indices[self.rank : self.total_size : self.num_replicas]
        assert len(self.subset) == self.num_samples

        return iter(self.subset)

    def __len__(self) -> int:
        return self.num_samples


236
237
238
239
240
241
242
243
class HuBERTDataSet(Dataset):
    """Create a Dataset for HuBERT model training and fine-tuning.

    Args:
        exp_dir (str or Path): The root directory of the ``.tsv`` file list.
        dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
        subset (str): The subset of the dataset. Options: [``train``, ``valid``].
    """

    def __init__(
        self,
        exp_dir: Union[str, Path],
        dataset: str,
        subset: str,
    ) -> None:
        self.exp_dir = Path(exp_dir)
        tsv_dir = self.exp_dir / "tsv"
        label_dir = self.exp_dir / "label"
        self.f_list, self.ind_list, self.len_list = self._get_lists(tsv_dir, dataset, subset)
        self.labels = self._load_labels(label_dir, dataset, subset)

    def __len__(self):
        return len(self.f_list)

    def _get_lists(
        self,
        tsv_dir: Path,
        dataset: str,
        subset: str,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Get the list of paths for iteration.

        The file ``{tsv_dir}/{dataset}_{subset}.tsv`` holds a root directory on its first line,
        followed by one ``relative_path<TAB>num_samples`` row per audio file.

        Args:
            tsv_dir (Path): The root directory of the ``.tsv`` file list.
            dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
            subset (str): The subset of the dataset. Options: [``train``, ``valid``].

        Returns:
            (numpy.array) List of file paths.
            (numpy.array) List of indices (row number within the ``.tsv`` body).
            (numpy.array) List of waveform lengths.
        """
        f_list, ind_list, len_list = [], [], []
        with open(tsv_dir / f"{dataset}_{subset}.tsv") as f:
            root = f.readline().rstrip()
            for index, line in enumerate(f):
                path, nsample = line.split("\t")
                f_list.append(f"{root}/{path}")
                ind_list.append(index)
                len_list.append(int(nsample))
        return np.asarray(f_list), np.asarray(ind_list), np.asarray(len_list)

    def _load_audio(self, index: int) -> Tensor:
        """Load waveform given the sample index of the dataset.

        Args:
            index (int): The sample index.

        Returns:
            (Tensor): The corresponding waveform Tensor.

        Raises:
            AssertionError: if the loaded waveform length differs from the one listed in the ``.tsv`` file.
        """
        wav_path = self.f_list[index]
        waveform, _ = torchaudio.load(wav_path)
        # Explicit check (not ``assert``) so the consistency guard survives ``python -O``.
        if waveform.shape[1] != self.len_list[index]:
            raise AssertionError(
                f"Waveform length {waveform.shape[1]} of {wav_path} does not match the "
                f"expected length {self.len_list[index]}."
            )
        return waveform

    def _load_labels(self, label_dir: Path, dataset: str, subset: str) -> np.ndarray:
        """Load all labels to memory into a numpy array.

        Args:
            label_dir (Path): The directory that contains the label file.
            dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
                NOTE(review): currently unused; the file name depends only on ``subset``.
            subset (str): The subset of the dataset. Options: [``train``, ``valid``].

        Returns:
            (np.ndarray): Byte-string array with one whitespace-separated label line per audio file,
                reordered to follow ``self.ind_list``.
        """
        with open(label_dir / f"label_{subset}.pt") as f:
            labels = [line.rstrip() for line in f]
        labels = [labels[i] for i in self.ind_list]
        # ``np.bytes_`` is the dtype formerly aliased as ``np.string_`` (removed in NumPy 2.0).
        return np.asarray(labels, dtype=np.bytes_)

    def __getitem__(self, index):
        """Return ``(waveform, label, length)`` for sample ``index``; ``label`` is an int64 Tensor."""
        waveform = self._load_audio(index)
        length = waveform.shape[1]
        # ``int`` accepts the byte strings stored in ``self.labels``.
        label = torch.tensor([int(ele) for ele in self.labels[index].split()], dtype=torch.long)
        return (waveform, label, length)


329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
def _crop_audio_label(
    waveform: Tensor,
    label: Tensor,
    length: int,
    num_frames: int,
    rand_crop: bool,
) -> Tuple[Tensor, Tensor, int]:
    """Crop the audio and its frame-level label at the same time.

    Args:
        waveform (Tensor): The waveform Tensor with dimensions `(1, time)`.
        label (Tensor): The label Tensor with dimensions `(seq,)`.
        length (int): The waveform length. Not used; the returned length is recomputed.
        num_frames (int): The target length of the waveform. If the waveform is shorter,
            it is kept whole.
        rand_crop (bool): if ``rand_crop`` is True, the starting index of the
            waveform and label is random if the length is longer than ``num_frames``.
            Otherwise the crop starts at index 0.

    Returns:
        (Tuple(Tensor, Tensor, int)): The cropped waveform `(num_frames,)`, the matching
            label slice, and the cropped waveform length.
    """
    kernel_size = 25
    stride = 20
    sample_rate = 16  # 16 per millisecond
    window = kernel_size * sample_rate  # receptive field of one label frame, in samples
    hop = stride * sample_rate  # distance between consecutive label frames, in samples
    frame_offset = 0
    waveform = waveform[0]
    if waveform.size(0) > num_frames and rand_crop:
        diff = waveform.size(0) - num_frames
        # Valid start offsets are 0..diff inclusive; ``torch.randint``'s upper bound is exclusive.
        frame_offset = int(torch.randint(diff + 1, size=(1,)))
    elif waveform.size(0) < num_frames:
        num_frames = waveform.size(0)
    label_offset = max(math.floor((frame_offset - window) / hop) + 1, 0)
    num_label = math.floor((num_frames - window) / hop) + 1
    waveform = waveform[frame_offset : frame_offset + num_frames]
    label = label[label_offset : label_offset + num_label]
    length = num_frames

    return waveform, label, length


369
370
371
372
373
class CollateFnHubert:
    """Collate callable that batches HuBERT pre-training / fine-tuning samples.

    Args:
        feature_type (str): Features the KMeans labels were computed from.
            Options: [``mfcc``, ``hubert``]. ``mfcc`` labels are downsampled by 2.
        pad (bool): When ``True``, every sample is kept at full length and the batch is
            zero-padded to the longest one. When ``False``, every sample is cropped to the
            shortest one in the batch. (Default: False)
        rand_crop (bool): When ``True``, samples longer than the target length are cropped
            starting at a random position. Otherwise cropping starts at 0. (Default: True)
    """

    def __init__(
        self,
        feature_type: str,
        pad: bool = False,
        rand_crop: bool = True,
    ) -> None:
        self.feature_type = feature_type
        self.pad = pad
        self.rand_crop = rand_crop

    def __call__(self, batch: List[Tuple[Tensor, Tensor, int]]) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Args:
            batch (List[Tuple(Tensor, Tensor, int)]):
                ``(waveform, label, length)`` tuples; the waveform length is read from ``shape[1]``.

        Returns:
            (Tuple(Tensor, Tensor, Tensor)):
                Waveforms shaped `(batch, time)`, labels shaped `(batch, seq)`, and
                per-sample waveform lengths shaped `(batch,)`.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        reduce_len = max if self.pad else min
        target_frames = reduce_len(item[0].shape[1] for item in batch)

        halve_labels = self.feature_type == "mfcc"
        cropped = []
        for wav, lab, wav_len in batch:
            # MFCC-derived labels are at 10 ms per frame, while the HuBERT transformer
            # emits one frame every 20 ms, so keep every other label.
            if halve_labels:
                lab = lab[::2]
            cropped.append(_crop_audio_label(wav, lab, wav_len, target_frames, self.rand_crop))

        wav_list = [entry[0] for entry in cropped]
        label_list = [entry[1] for entry in cropped]
        length_list = [entry[2] for entry in cropped]

        if not self.pad:
            # Cropping to the batch minimum must have produced equal sizes.
            ref_wav = wav_list[0].shape[0]
            assert all(
                w.shape[0] == ref_wav for w in wav_list
            ), "The dimensions of the waveforms should be identical in the same batch."
            ref_label = label_list[0].shape[0]
            assert all(
                lab.shape[0] == ref_label for lab in label_list
            ), "The dimensions of the labels should be identical in the same batch."

        pad_seq = torch.nn.utils.rnn.pad_sequence
        padded_wavs = pad_seq(wav_list, batch_first=True)
        padded_labels = pad_seq(label_list, batch_first=True)
        return padded_wavs, padded_labels, torch.tensor(length_list)
432
433


434
def _get_lengths_librilightlimited(files: List[Tuple[str, str]], path: str, ext_audio: str) -> List[int]:
    """Return the number of audio frames of every utterance in a LibriLight-limited file list.

    Args:
        files: ``(file_path, fileid)`` pairs. ``fileid`` has the form
            ``"{speaker_id}-{chapter_id}-{utterance_id}"``. The audio is read from
            ``{path}/{file_path}/{speaker_id}/{chapter_id}/{fileid}{ext_audio}``.
        path: Root directory of the dataset.
        ext_audio: Suffix appended to ``fileid`` to form the audio file name.

    Returns:
        List[int]: ``torchaudio.info(...).num_frames`` for each entry, in input order.

    Raises:
        ValueError: if a ``fileid`` does not split into exactly three ``-``-separated parts.
    """
    lengths = []
    for file_path, fileid in files:
        # Unpacking validates the id format; the id itself is the audio file stem.
        speaker_id, chapter_id, _utterance_id = fileid.split("-")
        file_audio = os.path.join(path, file_path, speaker_id, chapter_id, f"{fileid}{ext_audio}")
        lengths.append(torchaudio.info(file_audio).num_frames)
    return lengths


def _get_lengths_librispeech(files: List[str], path: str, ext_audio: str) -> List[int]:
    """Return the number of audio frames of every LibriSpeech utterance id in ``files``.

    Each id ``"{speaker}-{chapter}-{utterance}"`` is resolved to
    ``{path}/{speaker}/{chapter}/{speaker}-{chapter}-{utterance}{ext_audio}``.
    An id that does not split into exactly three parts raises ``ValueError``.
    """

    def _num_frames(fileid: str) -> int:
        speaker, chapter, utterance = fileid.split("-")
        audio_name = "-".join((speaker, chapter, utterance)) + ext_audio
        return torchaudio.info(os.path.join(path, speaker, chapter, audio_name)).num_frames

    return [_num_frames(fileid) for fileid in files]


class CollateFnLibriLightLimited:
    """The collate class for LibriSpeech or LibriLightLimited dataset."""

    def __call__(
        self, batch: List[Tuple[Tensor, int, str, int, int, int]]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Args:
            batch (List(Tuple(Tensor, int, str, int, int, int))):
                The list of tuples that contains
                waveform, sample_rate, transcript, speaker_id, chapter_id, and utterance_id.
                NOTE(review): waveforms are assumed to be mono `(1, time)`; confirm with the dataset.

        Returns:
            (Tuple(Tensor, Tensor, Tensor, Tensor)):
                The Tensor of zero-padded waveforms with dimensions `(batch, time)`.
                The Tensor of labels with dimensions `(batch, seq)`, padded with ``-1``.
                The Tensor of audio lengths with dimensions `(batch,)`.
                The Tensor of label lengths with dimensions `(batch,)`.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        audio_size = max(sample[0].shape[1] for sample in batch)
        label2id = _get_label2id()
        waveforms, labels, audio_lengths, label_lengths = [], [], [], []
        for sample in batch:
            waveform, transcript = sample[0], sample[2]
            # Add one "|" after the transcription as the word terminator, and map spaces to "|".
            chars = (transcript + "|").replace(" ", "|").upper()
            label = torch.tensor([label2id[c] for c in chars])
            waveforms.append(waveform)
            audio_lengths.append(waveform.size(1))
            label_lengths.append(label.size(0))
            labels.append(label)

        data = torch.zeros(len(batch), audio_size)
        for i, waveform in enumerate(waveforms):
            data[i][0 : waveform.shape[1]] = waveform
        audio_lengths = torch.tensor(audio_lengths)
        label_lengths = torch.tensor(label_lengths)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-1)
        return data, labels.int(), audio_lengths.int(), label_lengths.int()