import math

import torch
# Public, stable names exported by ``from <module> import *``.
__all__ = [
    'scale',
    'pad_trim',
    'downmix_mono',
    'LC2CL',
    'spectrogram',
    'create_fb_matrix',
    'spectrogram_to_DB',
    'create_dct',
    'BLC2CBL',
    'mu_law_encoding',
    'mu_law_expanding'
]

Jason Lian's avatar
Jason Lian committed
19

20
@torch.jit.script
def scale(tensor, factor):
    # type: (Tensor, int) -> Tensor
    """Normalize an integer-PCM audio tensor into the [-1.0, 1.0] float range.

    The input holds fixed-point sample values (e.g. a 16-bit "bit depth",
    not to be confused with "bit rate"); dividing by ``factor`` maps them
    onto floating point numbers between -1.0 and 1.0.

    Inputs:
        tensor (Tensor): Tensor of audio of size (Samples x Channels)
        factor (int): Maximum value of input tensor

    Outputs:
        Tensor: Scaled by the scale factor
    """
    if tensor.is_floating_point():
        return tensor / factor
    # integer samples: promote to float32 before scaling
    return tensor.to(torch.float32) / factor

@torch.jit.script
def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
    # type: (Tensor, int, int, int, float) -> Tensor
    """Pad/Trim a 2d-Tensor (Signal or Labels)

    Pads the length dimension up to ``max_len`` with ``fill_value``
    (appending on the right side only), or trims it down to ``max_len``;
    a tensor whose length already equals ``max_len`` is returned unchanged.

    Inputs:
        tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
        ch_dim (int): Dimension of channel (not size)
            NOTE(review): ch_dim is never read in this body — presumably kept
            for signature symmetry with related functions; confirm callers.
        max_len (int): Length to which the tensor will be padded
        len_dim (int): Dimension of length (not size)
        fill_value (float): Value to fill in

    Outputs:
        Tensor: Padded/trimmed tensor
    """
    if max_len > tensor.size(len_dim):
        # F.pad takes [pad_left, pad_right, pad_top, pad_bottom], where the
        # FIRST pair applies to the LAST dimension.  Odd indices (i % 2 == 1)
        # are the right/bottom sides, so padding is only ever appended; the
        # (i // 2 != len_dim) test picks the pair that maps onto the length
        # dimension of a 2d tensor (pair 0 -> dim 1, pair 1 -> dim 0).
        padding = [max_len - tensor.size(len_dim)
                   if (i % 2 == 1) and (i // 2 != len_dim)
                   else 0
                   for i in [0, 1, 2, 3]]
        # TODO add "with torch.no_grad():" back when JIT supports it
        tensor = torch.nn.functional.pad(tensor, padding, "constant", fill_value)
    elif max_len < tensor.size(len_dim):
        # too long: keep the first max_len entries along the length dimension
        tensor = tensor.narrow(len_dim, 0, max_len)
    return tensor

@torch.jit.script
def downmix_mono(tensor, ch_dim):
    # type: (Tensor, int) -> Tensor
    """Collapse a multi-channel signal to mono by averaging the channel axis.

    Inputs:
        tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
        ch_dim (int): Dimension of channel (not size)

    Outputs:
        Tensor: Mono signal (channel dimension is kept with size 1)
    """
    # mean() needs a floating dtype, so integer audio is promoted first
    samples = tensor if tensor.is_floating_point() else tensor.to(torch.float32)
    return samples.mean(ch_dim, True)

@torch.jit.script
def LC2CL(tensor):
    # type: (Tensor) -> Tensor
    """Reorder a (length x channel) signal into (channel x length) layout.

    Inputs:
        tensor (Tensor): Tensor of audio signal with shape (LxC)

    Outputs:
        Tensor: Tensor of audio signal with shape (CxL)
    """
    swapped = tensor.transpose(0, 1)
    # contiguous() materializes the new layout instead of a strided view
    return swapped.contiguous()

def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
    # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
    """Forward every argument positionally to ``torch.stft``.

    Exists so TorchScript callers get an explicit Optional-typed signature
    for the stft call.

    NOTE(review): newer PyTorch releases require a ``return_complex``
    argument to ``torch.stft``; this wrapper predates it — confirm the
    installed torch version still accepts this call form.
    """
    return torch.stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)


@torch.jit.script
def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
    # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
    """Create a spectrogram from a raw audio signal

    Inputs:
        sig (Tensor): Tensor of audio of size (c, n)
        pad (int): two sided padding of signal
        window (Tensor): window_tensor
        n_fft (int): size of fft
        hop (int): length of hop between STFT windows
        ws (int): window size
        power (int > 0 ) : Exponent for the magnitude spectrogram,
                        e.g., 1 for energy, 2 for power, etc.
        normalize (bool) : whether to normalize by magnitude after stft

    Outputs:
        Tensor: channels x hops x n_fft (c, l, f), where channels
            is unchanged, hops is the number of hops, and n_fft is the
            number of fourier bins, which should be the window size divided
            by 2 plus 1.
    """
    assert sig.dim() == 2

    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        sig = torch.nn.functional.pad(sig, (pad, pad), "constant")

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = _stft(sig, n_fft, hop, ws, window,
                   True, 'reflect', False, True).transpose(1, 2)

    if normalize:
        spec_f /= window.pow(2).sum().sqrt()
    # Magnitude**power: square the real/imag parts, sum them, then raise the
    # squared magnitude to power/2.  The previous code applied ``power``
    # element-wise BEFORE the sum (re**p + im**p), which is only correct for
    # power == 2; for power == 1 it did not yield the magnitude.
    spec_f = spec_f.pow(2.).sum(-1)  # squared magnitude of "complex" tensor (c, l, n_fft)
    if power != 2:
        spec_f = spec_f.pow(float(power) / 2.)
    return spec_f
@torch.jit.script
def create_fb_matrix(n_stft, f_min, f_max, n_mels):
    # type: (int, float, float, int) -> Tensor
    """ Create a frequency bin conversion matrix.

    Inputs:
        n_stft (int): number of filter banks from spectrogram
        f_min (float): minimum frequency
        f_max (float): maximum frequency
        n_mels (int): number of mel bins

    Outputs:
        Tensor: triangular filter banks (fb matrix)
    """
    # Linearly spaced frequencies of the spectrogram bins, in hertz.
    freqs = torch.linspace(f_min, f_max, n_stft)
    # Triangle edge points are uniform on the mel scale:
    #   mel(f) = 2595 * log10(1 + f / 700)
    mel_min = 2595. * math.log10(1. + (f_min / 700.)) if f_min != 0 else 0.
    mel_max = 2595. * math.log10(1. + (f_max / 700.))
    mel_pts = torch.linspace(mel_min, mel_max, n_mels + 2)
    # Back to hertz: f(mel) = 700 * (10**(mel / 2595) - 1)
    hz_pts = 700. * (10**(mel_pts / 2595.) - 1.)
    # Width of each mel band in hertz, shape (n_mels + 1,)
    band_widths = hz_pts[1:] - hz_pts[:-1]
    # Signed distance from every stft bin to every edge point, (n_stft, n_mels + 2)
    deltas = hz_pts.unsqueeze(0) - freqs.unsqueeze(1)
    # Rising and falling ramps of each overlapping triangle, (n_stft, n_mels)
    ramp_up = (-1. * deltas[:, :-2]) / band_widths[:-1]
    ramp_down = deltas[:, 2:] / band_widths[1:]
    # Intersect the two ramps and clip negatives to zero.
    floor = torch.zeros(1)
    return torch.max(floor, torch.min(ramp_up, ramp_down))


@torch.jit.script
def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
    # type: (Tensor, float, float, float, Optional[float]) -> Tensor
    """Turns a spectrogram from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    a full clip.

    Inputs:
        spec (Tensor): normal STFT
        multiplier (float): use 10. for power and 20. for amplitude
        amin (float): number to clamp spec
        db_multiplier (float): log10(max(reference value and amin))
        top_db (Optional[float]): minimum negative cut-off in decibels.  A reasonable number
            is 80.

    Outputs:
        Tensor: spectrogram in DB
    """
    # clamp away zeros so log10 stays finite, then convert to decibels
    spec_db = torch.log10(torch.clamp(spec, min=amin)) * multiplier
    spec_db = spec_db - multiplier * db_multiplier

    if top_db is not None:
        # lift everything below (global max - top_db) up to that floor
        db_floor = torch.tensor(float(spec_db.max()) - top_db, dtype=spec_db.dtype, device=spec_db.device)
        spec_db = torch.max(spec_db, db_floor)

    return spec_db
@torch.jit.script
def create_dct(n_mfcc, n_mels, norm):
    # type: (int, int, Optional[str]) -> Tensor
    """
    Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
    normalized depending on norm

    Inputs:
        n_mfcc (int) : number of mfc coefficients to retain
        n_mels (int): number of MEL bins
        norm (Optional[str]) : norm to use (either 'ortho' or None)

    Outputs:
        Tensor: The transformation matrix, to be right-multiplied to row-wise data.
    """
    # DCT-II basis: cos(pi / N * (n + 1/2) * k), see
    # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    samples = torch.arange(n_mels)
    coeffs = torch.arange(n_mfcc).unsqueeze(1)
    basis = torch.cos(math.pi / float(n_mels) * (samples + 0.5) * coeffs)
    if norm is None:
        basis *= 2.0
    else:
        assert norm == 'ortho'
        # orthonormal scaling: coefficient 0 gets an extra 1/sqrt(2)
        basis[0] *= 1.0 / math.sqrt(2.0)
        basis *= math.sqrt(2.0 / float(n_mels))
    return basis.t()
@torch.jit.script
def BLC2CBL(tensor):
    # type: (Tensor) -> Tensor
    """Move the channel axis of a Bands x Sample length x Channels spectrogram
    to the front, yielding Channels x Bands x Samples length.

    Inputs:
        tensor (Tensor): Tensor of spectrogram with shape (BxLxC)

    Outputs:
        Tensor: Tensor of spectrogram with shape (CxBxL)
    """
    # (B, L, C) -> (B, C, L) -> (C, B, L), then materialize the layout
    rearranged = tensor.transpose(1, 2).transpose(0, 1)
    return rearranged.contiguous()


@torch.jit.script
def mu_law_encoding(x, qc):
    # type: (Tensor, int) -> Tensor
    """Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Inputs:
        x (Tensor): Input tensor
        qc (int): Number of channels (i.e. quantization channels)

    Outputs:
        Tensor: Input after mu-law companding
    """
    assert isinstance(x, torch.Tensor), 'mu_law_encoding expects a Tensor'
    if not x.is_floating_point():
        x = x.to(torch.float)
    # Bind mu as a Tensor in a single assignment: the old code first bound it
    # as a float and then rebound it to a Tensor, which violates TorchScript's
    # rule that a variable keeps one static type.
    mu = torch.tensor(qc - 1., dtype=x.dtype)
    # companding: y = sign(x) * log(1 + mu*|x|) / log(1 + mu)
    x_mu = torch.sign(x) * torch.log1p(mu *
                                       torch.abs(x)) / torch.log1p(mu)
    # map [-1, 1] onto the integer bins [0, qc - 1]
    x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
    return x_mu


@torch.jit.script
def mu_law_expanding(x_mu, qc):
    # type: (Tensor, int) -> Tensor
    """Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Inputs:
        x_mu (Tensor): Input tensor
        qc (int): Number of channels (i.e. quantization channels)

    Outputs:
        Tensor: Input after decoding
    """
    assert isinstance(x_mu, torch.Tensor), 'mu_law_expanding expects a Tensor'
    if not x_mu.is_floating_point():
        x_mu = x_mu.to(torch.float)
    # Bind mu as a Tensor in a single assignment: the old code first bound it
    # as a float and then rebound it to a Tensor, which violates TorchScript's
    # rule that a variable keeps one static type.
    mu = torch.tensor(qc - 1., dtype=x_mu.dtype)
    # undo the quantization: integer bins back onto [-1, 1]
    x = ((x_mu) / mu) * 2 - 1.
    # inverse companding: sign(y) * ((1 + mu)**|y| - 1) / mu
    x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
    return x