utils.py 2.2 KB
Newer Older
mayp777's avatar
UPDATE  
mayp777 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing import Dict, List, Tuple

import torch
from torch import Tensor


class CollateFnL3DAS22:
    """Collate function for the L3DAS22 dataset.

    Crops every noisy/clean waveform pair in a mini-batch to a fixed number
    of samples and stacks the results into batch tensors.

    Args:
        audio_length (int): Number of samples kept from each waveform.
            (Default: ``16000 * 4``, i.e. four seconds at 16 kHz)
        rand_crop (bool): If ``True``, the crop start index is drawn
            uniformly at random whenever a waveform is longer than
            ``audio_length``; otherwise the crop always starts at sample 0.
            (Default: ``True``)
    """

    def __init__(
        self,
        audio_length: int = 16000 * 4,
        rand_crop: bool = True,
    ) -> None:
        self.audio_length = audio_length
        self.rand_crop = rand_crop

    def __call__(self, batch: List[Tuple[Tensor, Tensor, int, str]]) -> Tuple[Tensor, Tensor]:
        """Crop and batch a list of dataset samples.

        Args:
            batch (List[Tuple[Tensor, Tensor, int, str]]):
                The list of tuples that contains:
                - noisy (mixture) waveform with dimensions `(channel, time)`
                - clean waveform, assumed single-channel `(1, time)`
                - sample rate (unused)
                - transcript (unused)

        Returns:
            (Tensor, Tensor):
                Noisy waveforms with dimensions `(batch, channel, audio_length)`
                and clean waveforms with dimensions `(batch, audio_length)`.
        """
        waveforms_noisy, waveforms_clean = [], []
        for waveform_noisy, waveform_clean, _sample_rate, _transcript in batch:
            diff = waveform_noisy.size(-1) - self.audio_length
            # torch.randint requires its upper bound to be > 0, so fall back
            # to offset 0 when the waveform is exactly audio_length samples
            # long (or shorter).
            if self.rand_crop and diff > 0:
                frame_offset = int(torch.randint(diff, size=(1,)))
            else:
                frame_offset = 0
            waveforms_noisy.append(
                waveform_noisy[:, frame_offset : frame_offset + self.audio_length].unsqueeze(0)
            )
            # NOTE(review): the clean waveform gets no unsqueeze, so cat along
            # dim 0 yields `(batch, time)` only if it is `(1, time)` — confirm
            # against the dataset if multi-channel clean targets are possible.
            waveforms_clean.append(
                waveform_clean[:, frame_offset : frame_offset + self.audio_length]
            )
        return torch.cat(waveforms_noisy), torch.cat(waveforms_clean)