"vscode:/vscode.git/clone" did not exist on "fc78640e00e39520fa7126789d23369d2f104d0c"
processing.py 1.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import librosa
import torch
import torch.nn as nn


# TODO: Replace with torchaudio once https://github.com/pytorch/audio/pull/593 is resolved.
class LinearToMel(nn.Module):
    r"""Project a linear-frequency spectrogram onto mel bands using librosa.

    This is a thin wrapper around ``librosa.feature.melspectrogram`` called
    with a precomputed spectrogram (``S=...``), so no STFT is computed here.

    Args:
        sample_rate: Sampling rate forwarded to librosa as ``sr``.
        n_fft: FFT size forwarded to librosa; presumably the one used to
            compute the incoming spectrogram (confirm with caller).
        n_mels: Number of mel bands in the output.
        fmin: Lowest filterbank frequency, forwarded to librosa.
        htk: Forwarded to librosa's mel filterbank construction.
        norm: Forwarded to librosa's mel filterbank construction.
    """

    def __init__(self, sample_rate, n_fft, n_mels, fmin, htk=False, norm="slaney"):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.fmin = fmin
        self.htk = htk
        self.norm = norm

    def forward(self, specgram):
        r"""Convert ``specgram`` to a mel spectrogram.

        Args:
            specgram: Tensor whose leading dimension is squeezed away if it has
                size 1 — presumably shaped ``(1, freq, time)``; confirm with
                caller.

        Returns:
            The mel spectrogram produced by librosa, as a tensor on the same
            device as ``specgram``.
        """
        # ``Tensor.numpy()`` only works on CPU tensors that do not require
        # grad, so detach and move to CPU before handing the data to librosa.
        spec_np = specgram.squeeze(0).detach().cpu().numpy()
        mel = librosa.feature.melspectrogram(
            S=spec_np,
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.fmin,
            htk=self.htk,
            norm=self.norm,
        )
        # Return on the caller's device; this is a no-op for CPU inputs.
        return torch.from_numpy(mel).to(specgram.device)


class NormalizeDB(nn.Module):
    r"""Convert a magnitude spectrogram to decibels and rescale it into ``[0, 1]``.

    The input is floored at ``1e-5`` and converted with ``20 * log10(x)``.
    It is then rescaled linearly so that ``min_level_db`` maps to 0 and
    0 dB maps to 1. Finally the result is clipped to ``[0, 1]``.

    Args:
        min_level_db: Decibel level mapped to 0. Presumably negative
            (e.g. ``-100``); a value of 0 would divide by zero.
    """

    def __init__(self, min_level_db):
        super().__init__()
        self.min_level_db = min_level_db

    def forward(self, specgram):
        floor_db = self.min_level_db
        # Floor the magnitudes so log10 never sees zero.
        decibels = torch.log10(specgram.clamp(min=1e-5)).mul(20)
        rescaled = (floor_db - decibels) / floor_db
        return rescaled.clamp(min=0, max=1)


def normalized_waveform_to_bits(waveform, bits):
    r"""Quantize a waveform in ``[-1, 1]`` to integer labels in ``[0, 2 ** bits - 1]``.

    The samples are shifted and scaled linearly, clamped to the label range,
    then converted with ``.int()``. Values are non-negative at that point, so
    the conversion truncates (floors) rather than rounds.

    Args:
        waveform: Tensor of samples expected to lie in ``[-1, 1]``.
        bits: Bit depth; the output has ``2 ** bits`` possible levels.

    Returns:
        ``torch.int32`` tensor of labels with the same shape as ``waveform``.
        An empty input yields an empty tensor.

    Raises:
        AssertionError: If a non-empty ``waveform`` contains a sample outside
            ``[-1, 1]`` (or a NaN). The check only runs when assertions are
            enabled.
    """
    # ``max()`` on an empty tensor raises a RuntimeError, so skip the range
    # check when there is nothing to check.
    assert waveform.numel() == 0 or abs(waveform).max() <= 1.0, (
        "waveform must be normalized to [-1, 1]"
    )
    max_label = 2 ** bits - 1
    waveform = (waveform + 1.0) * max_label / 2
    # The clamp keeps labels in range even if the assert is stripped (python -O).
    return torch.clamp(waveform, 0, max_label).int()


def bits_to_normalized_waveform(label, bits):
    r"""Map integer labels in ``[0, 2 ** bits - 1]`` back to values in ``[-1, 1]``.

    Label 0 maps to -1.0 and label ``2 ** bits - 1`` maps to 1.0. Only
    arithmetic operators are used, so Python numbers and tensors both work.
    Integer tensors are promoted to floating point by the true division.
    No range check is performed.

    Args:
        label: Quantized sample value(s).
        bits: Bit depth used for the quantization.

    Returns:
        The dequantized value(s).
    """
    top_label = 2 ** bits - 1.0
    scaled = 2 * label / top_label
    return scaled - 1.0