# -*- coding: utf-8 -*-

import math

from typing import Callable, Optional
from warnings import warn

import torch
from torch import Tensor

from torchaudio import functional as F
from torchaudio.compliance import kaldi


__all__ = [
    'Spectrogram',
    'GriffinLim',
    'AmplitudeToDB',
    'MelScale',
    'InverseMelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'TimeStretch',
    'Fade',
    'FrequencyMasking',
    'TimeMasking',
    'SlidingWindowCmn',
    'Vad',
]


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None) -> None:
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` with ``n_fft`` the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
                             self.win_length, self.power, self.normalized)
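
# Usage sketch (illustrative, not part of the original module): power
# spectrogram of one second of noise.
#
#   >>> waveform = torch.randn(1, 16000)             # (channel, time)
#   >>> spec = Spectrogram(n_fft=400)(waveform)      # (1, n_fft // 2 + 1, n_frames)
#   >>> spec.shape                                   # hop defaults to 200, frames centered
#   torch.Size([1, 201, 81])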


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
        Implementation ported from `librosa`.

    .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

    .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.

    .. [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iterations for phase recovery process. (Default: ``32``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
                     'length', 'momentum', 'rand_init']

    def __init__(self,
                 n_fft: int = 400,
                 n_iter: int = 32,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: float = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 momentum: float = 0.99,
                 length: Optional[int] = None,
                 rand_init: bool = True) -> None:
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
        assert momentum > 0, 'momentum=%s < 0' % momentum

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.normalized = normalized
        self.length = length
        self.power = power
        self.momentum = momentum / (1 + momentum)
        self.rand_init = rand_init

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
                where freq is ``n_fft // 2 + 1``.

        Returns:
            Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
        """
        return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)
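
# Round-trip sketch (illustrative): rebuild a waveform from a magnitude-only
# power spectrogram; the phase is estimated, so samples will not match exactly.
#
#   >>> specgram = Spectrogram(n_fft=400, power=2.)(torch.randn(1, 16000))
#   >>> waveform = GriffinLim(n_fft=400, power=2., n_iter=32)(specgram)
#   >>> waveform.shape        # (1, time)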


class AmplitudeToDB(torch.nn.Module):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None:
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a positive value')
        self.top_db = top_db
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: Tensor) -> Tensor:
        r"""Numerically stable implementation from Librosa.

        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

        Args:
            x (Tensor): Input tensor before being converted to decibel scale.

        Returns:
            Tensor: Output tensor in decibel scale.
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
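
# Usage sketch (illustrative): convert a power spectrogram to decibels,
# clamping the dynamic range to 80 dB below the peak.
#
#   >>> spec = Spectrogram()(torch.randn(1, 16000))
#   >>> spec_db = AmplitudeToDB(stype='power', top_db=80.)(spec)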


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    The user can control which device the filter bank (`fb`) resides on
    (e.g. ``fb.to(spec_f.device)``).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None) -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # pack batch
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])

        return mel_specgram
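
# Usage sketch (illustrative): project a 201-bin linear spectrogram
# (n_fft=400) onto 128 mel bins.
#
#   >>> spec = Spectrogram(n_fft=400)(torch.randn(1, 16000))    # (1, 201, time)
#   >>> mel = MelScale(n_mels=128, sample_rate=16000, n_stft=201)(spec)
#   >>> mel.shape                                               # (1, 128, time)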


class InverseMelScale(torch.nn.Module):
    r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    It minimizes the euclidian norm between the input mel-spectrogram and the product between
    the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
277
278
279
280
281
282
283
284
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
moto's avatar
moto committed
285
286
287
288
    """
    __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
                     'tolerance_change', 'sgdargs']

289
290
291
292
293
294
295
296
297
298
    def __init__(self,
                 n_stft: int,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 max_iter: int = 100000,
                 tolerance_loss: float = 1e-5,
                 tolerance_change: float = 1e-8,
                 sgdargs: Optional[dict] = None) -> None:
moto's avatar
moto committed
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

        assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)

        fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

314
    def forward(self, melspec: Tensor) -> Tensor:
moto's avatar
moto committed
315
316
        r"""
        Args:
317
            melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)
moto's avatar
moto committed
318
319

        Returns:
320
            Tensor: Linear scale spectrogram of size (..., freq, time)
moto's avatar
moto committed
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
        """
        # pack batch
        shape = melspec.size()
        melspec = melspec.view(-1, shape[-2], shape[-1])

        n_mels, time = shape[-2], shape[-1]
        freq, _ = self.fb.size()  # (freq, n_mels)
        melspec = melspec.transpose(-1, -2)
        assert self.n_mels == n_mels

        specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
                              dtype=melspec.dtype, device=melspec.device)

        optim = torch.optim.SGD([specgram], **self.sgdargs)

        loss = float('inf')
        for _ in range(self.max_iter):
            optim.zero_grad()
            diff = melspec - specgram.matmul(self.fb)
            new_loss = diff.pow(2).sum(axis=-1).mean()
            # take sum over mel-frequency then average over other dimensions
            # so that loss threshold is applied par unit timeframe
            new_loss.backward()
            optim.step()
            specgram.data = specgram.data.clamp(min=0)

            new_loss = new_loss.item()
            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                break
            loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram
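
# Usage sketch (illustrative; the SGD solve is iterative, so max_iter is
# lowered here to keep the example quick at the cost of accuracy):
#
#   >>> mel = MelSpectrogram(n_fft=400)(torch.randn(1, 16000))            # (1, 128, time)
#   >>> spec = InverseMelScale(n_stft=201, n_mels=128, max_iter=1000)(mel)
#   >>> spec.shape                                                        # (1, 201, time)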


class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 pad: int = 0,
                 n_mels: int = 128,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 wkwargs: Optional[dict] = None) -> None:
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=2.,
                                       normalized=False, wkwargs=wkwargs)
        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram
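
# Usage sketch (illustrative): equivalent to Spectrogram(power=2.) followed by
# MelScale.
#
#   >>> waveform = torch.randn(1, 16000)
#   >>> mel_specgram = MelSpectrogram(sample_rate=16000, n_mels=128)(waveform)
#   >>> mel_specgram.shape    # (1, 128, time)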


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
        melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_mfcc: int = 40,
                 dct_type: int = 2,
                 norm: str = 'ortho',
                 log_mels: bool = False,
                 melkwargs: Optional[dict] = None) -> None:
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
        """
        # pack batch
        shape = waveform.size()
        waveform = waveform.reshape(-1, shape[-1])

        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])

        return mfcc
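
# Usage sketch (illustrative): 40 coefficients from a DB-scaled mel
# spectrogram; melkwargs is forwarded to the underlying MelSpectrogram.
#
#   >>> mfcc = MFCC(sample_rate=16000, n_mfcc=40,
#   ...             melkwargs={'n_fft': 400, 'n_mels': 128})(torch.randn(1, 16000))
#   >>> mfcc.shape    # (1, 40, time)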


class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): A signal to be encoded.

        Returns:
            x_mu (Tensor): An encoded signal.
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu: Tensor) -> Tensor:
        r"""
        Args:
            x_mu (Tensor): A mu-law encoded signal which needs to be decoded.

        Returns:
            Tensor: The signal decoded.
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)
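
# Round-trip sketch (illustrative): mu-law encode to 256 integer levels, then
# decode back; the result is a quantized approximation of the input.
#
#   >>> x = torch.rand(1, 16000) * 2 - 1                    # scaled to [-1, 1]
#   >>> x_mu = MuLawEncoding(quantization_channels=256)(x)  # ints in [0, 255]
#   >>> x_hat = MuLawDecoding(quantization_channels=256)(x_mu)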


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    Args:
        orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (int, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
    """

    def __init__(self,
                 orig_freq: int = 16000,
                 new_freq: int = 16000,
                 resampling_method: str = 'sinc_interpolation') -> None:
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Output signal of dimension (..., time).
        """
        if self.resampling_method == 'sinc_interpolation':
            # pack batch
            shape = waveform.size()
            waveform = waveform.view(-1, shape[-1])

            waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

            # unpack batch
            waveform = waveform.view(shape[:-1] + waveform.shape[-1:])

            return waveform

        raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
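
# Usage sketch (illustrative): downsample from 16 kHz to 8 kHz.
#
#   >>> waveform = torch.randn(1, 16000)
#   >>> resampled = Resample(orig_freq=16000, new_freq=8000)(waveform)
#   >>> resampled.shape    # (1, time) at the new rate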


class ComplexNorm(torch.nn.Module):
    r"""Compute the norm of complex tensor input.

    Args:
        power (float, optional): Power of the norm. (Default: ``1.0``)
    """
    __constants__ = ['power']

    def __init__(self, power: float = 1.0) -> None:
        super(ComplexNorm, self).__init__()
        self.power = power

    def forward(self, complex_tensor: Tensor) -> Tensor:
        r"""
        Args:
            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.

        Returns:
            Tensor: norm of the input tensor, shape of `(..., )`.
        """
        return F.complex_norm(complex_tensor, self.power)
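
# Usage sketch (illustrative): magnitude of a complex spectrogram produced
# with power=None.
#
#   >>> complex_spec = Spectrogram(power=None)(torch.randn(1, 16000))  # (1, freq, time, 2)
#   >>> mag = ComplexNorm(power=1.)(complex_spec)                      # (1, freq, time)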


class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int): The window length used for computing delta. (Default: ``5``)
        mode (str): Mode parameter passed to padding. (Default: ``'replicate'``)
    """
    __constants__ = ['win_length']

    def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of audio of dimension (..., freq, time).

        Returns:
            Tensor: Tensor of deltas of dimension (..., freq, time).
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
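
# Usage sketch (illustrative): first- and second-order deltas of an MFCC.
#
#   >>> mfcc = MFCC(sample_rate=16000)(torch.randn(1, 16000))
#   >>> deltas = ComputeDeltas(win_length=5)(mfcc)
#   >>> deltas2 = ComputeDeltas(win_length=5)(deltas)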


class TimeStretch(torch.nn.Module):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
    """
    __constants__ = ['fixed_rate']

    def __init__(self,
                 hop_length: Optional[int] = None,
                 n_freq: int = 201,
                 fixed_rate: Optional[float] = None) -> None:
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
        self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])

    def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
        r"""
        Args:
            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
            overriding_rate (float or None, optional): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

        Returns:
            Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
        """
        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"

        if overriding_rate is None:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified"
                                 ", must pass a valid rate to the forward method.")
        else:
            rate = overriding_rate

        if rate == 1.0:
            return complex_specgrams

        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
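
# Usage sketch (illustrative): requires a complex spectrogram (power=None),
# since the phase vocoder needs phase information.
#
#   >>> complex_spec = Spectrogram(n_fft=400, power=None)(torch.randn(1, 16000))
#   >>> stretched = TimeStretch(n_freq=201, fixed_rate=1.2)(complex_spec)
#   >>> stretched.shape    # (1, 201, ceil(time / 1.2), 2)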


class Fade(torch.nn.Module):
    r"""Add a fade in and/or fade out to an waveform.

    Args:
        fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``)
        fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``)
        fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine",
            "half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``)
    """
710
711
712
713
    def __init__(self,
                 fade_in_len: int = 0,
                 fade_out_len: int = 0,
                 fade_shape: str = "linear") -> None:
Tomás Osório's avatar
Tomás Osório committed
714
715
716
717
718
        super(Fade, self).__init__()
        self.fade_in_len = fade_in_len
        self.fade_out_len = fade_out_len
        self.fade_shape = fade_shape

719
    def forward(self, waveform: Tensor) -> Tensor:
Tomás Osório's avatar
Tomás Osório committed
720
721
        r"""
        Args:
722
            waveform (Tensor): Tensor of audio of dimension (..., time).
Tomás Osório's avatar
Tomás Osório committed
723
724

        Returns:
725
            Tensor: Tensor of audio of dimension (..., time).
Tomás Osório's avatar
Tomás Osório committed
726
727
        """
        waveform_length = waveform.size()[-1]
moto's avatar
moto committed
728
729
730
        device = waveform.device
        return self._fade_in(waveform_length).to(device) * \
            self._fade_out(waveform_length).to(device) * waveform
Tomás Osório's avatar
Tomás Osório committed
731

732
    def _fade_in(self, waveform_length: int) -> Tensor:
Tomás Osório's avatar
Tomás Osório committed
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
        fade = torch.linspace(0, 1, self.fade_in_len)
        ones = torch.ones(waveform_length - self.fade_in_len)

        if self.fade_shape == "linear":
            fade = fade

        if self.fade_shape == "exponential":
            fade = torch.pow(2, (fade - 1)) * fade

        if self.fade_shape == "logarithmic":
            fade = torch.log10(.1 + fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5

        return torch.cat((fade, ones)).clamp_(0, 1)

753
    def _fade_out(self, waveform_length: int) -> Tensor:
Tomás Osório's avatar
Tomás Osório committed
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
        fade = torch.linspace(0, 1, self.fade_out_len)
        ones = torch.ones(waveform_length - self.fade_out_len)

        if self.fade_shape == "linear":
            fade = - fade + 1

        if self.fade_shape == "exponential":
            fade = torch.pow(2, - fade) * (1 - fade)

        if self.fade_shape == "logarithmic":
            fade = torch.log10(1.1 - fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2 + math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5

        return torch.cat((ones, fade)).clamp_(0, 1)
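
# Usage sketch (illustrative): fade in over the first 1000 samples and fade
# out over the last 2000.
#
#   >>> waveform = torch.randn(1, 16000)
#   >>> faded = Fade(fade_in_len=1000, fade_out_len=2000, fade_shape="linear")(waveform)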


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
            This option is applicable only when the input tensor is 4D.
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:
        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks

    def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of dimension (..., freq, time).
            mask_value (float): Value to assign to the masked columns.

        Returns:
            Tensor: Masked spectrogram of dimensions (..., freq, time).
        """
        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
    """
    def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
    """
    def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
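
# Usage sketch (illustrative): SpecAugment-style masking on a mel spectrogram;
# each call samples a new mask position and width.
#
#   >>> mel = MelSpectrogram()(torch.randn(1, 16000))
#   >>> mel = FrequencyMasking(freq_mask_param=30)(mel)
#   >>> mel = TimeMasking(time_mask_param=40)(mel)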


class Vol(torch.nn.Module):
    r"""Add a volume to an waveform.

    Args:
        gain (float): Interpreted according to the given gain_type:
Vincent QB's avatar
Vincent QB committed
842
843
844
845
            If ``gain_type`` = ``amplitude``, ``gain`` is a positive amplitude ratio.
            If ``gain_type`` = ``power``, ``gain`` is a power (voltage squared).
            If ``gain_type`` = ``db``, ``gain`` is in decibels.
        gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``)
Tomás Osório's avatar
Tomás Osório committed
846
847
    """

848
    def __init__(self, gain: float, gain_type: str = 'amplitude'):
Tomás Osório's avatar
Tomás Osório committed
849
850
851
852
853
854
855
        super(Vol, self).__init__()
        self.gain = gain
        self.gain_type = gain_type

        if gain_type in ['amplitude', 'power'] and gain < 0:
            raise ValueError("If gain_type = amplitude or power, gain must be positive.")

856
    def forward(self, waveform: Tensor) -> Tensor:
Tomás Osório's avatar
Tomás Osório committed
857
858
        r"""
        Args:
859
            waveform (Tensor): Tensor of audio of dimension (..., time).
Tomás Osório's avatar
Tomás Osório committed
860
861

        Returns:
862
            Tensor: Tensor of audio of dimension (..., time).
Tomás Osório's avatar
Tomás Osório committed
863
864
865
866
867
868
869
870
871
872
873
        """
        if self.gain_type == "amplitude":
            waveform = waveform * self.gain

        if self.gain_type == "db":
            waveform = F.gain(waveform, self.gain)

        if self.gain_type == "power":
            waveform = F.gain(waveform, 10 * math.log10(self.gain))

        return torch.clamp(waveform, -1, 1)
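
# Usage sketch (illustrative): halve the amplitude, or boost by 3 dB; either
# way the output is clamped to [-1, 1].
#
#   >>> waveform = torch.rand(1, 16000) * 2 - 1
#   >>> quieter = Vol(gain=0.5, gain_type='amplitude')(waveform)
#   >>> louder = Vol(gain=3.0, gain_type='db')(waveform)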


class SlidingWindowCmn(torch.nn.Module):
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        cmn_window (int, optional): Window in frames for running average CMN computation. (Default: ``600``)
        min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center == true. (Default: ``100``)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (Default: ``False``)
        norm_vars (bool, optional): If true, normalize variance to one. (Default: ``False``)
    """

    def __init__(self,
                 cmn_window: int = 600,
                 min_cmn_window: int = 100,
                 center: bool = False,
                 norm_vars: bool = False) -> None:
        super().__init__()
        self.cmn_window = cmn_window
        self.min_cmn_window = min_cmn_window
        self.center = center
        self.norm_vars = norm_vars

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        cmn_waveform = F.sliding_window_cmn(
            waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
        return cmn_waveform
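
# Usage sketch (illustrative; `feats` is a hypothetical feature tensor, e.g.
# MFCCs; the parameters mirror Kaldi's apply-cmvn-sliding):
#
#   >>> cmn = SlidingWindowCmn(cmn_window=600, min_cmn_window=100, center=True)
#   >>> normed = cmn(feats)   # same shape as the input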


class Vad(torch.nn.Module):
    r"""Voice Activity Detector. Similar to SoX implementation.
    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
    The algorithm currently uses a simple cepstral power measurement to detect voice,
    so may be fooled by other things, especially music.

    The effect can trim only from the front of the audio,
    so in order to trim from the back, the reverse effect must also be used.

    Args:
        sample_rate (int): Sample rate of audio signal.
        trigger_level (float, optional): The measurement level used to trigger activity detection.
            This may need to be changed depending on the noise level, signal level,
            and other characteristics of the input audio. (Default: 7.0)
        trigger_time (float, optional): The time constant (in seconds)
            used to help ignore short bursts of sound. (Default: 0.25)
        search_time (float, optional): The amount of audio (in seconds)
            to search for quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 1.0)
        allowed_gap (float, optional): The allowed gap (in seconds) between
            quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 0.25)
        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
        boot_time (float, optional): The algorithm (internally) uses adaptive noise
            estimation/reduction in order to detect the start of the wanted audio.
            This option sets the time for the initial noise estimate. (Default: 0.35)
        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is increasing. (Default: 0.1)
        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is decreasing. (Default: 0.01)
        noise_reduction_amount (float, optional): Amount of noise reduction to use in
            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
        measure_freq (float, optional): Frequency of the algorithm's
            processing/measurements. (Default: 20.0)
        measure_duration (float, optional): Measurement duration.
            (Default: Twice the measurement period; i.e. with overlap.)
        measure_smooth_time (float, optional): Time constant used to smooth
            spectral measurements. (Default: 0.4)
        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
            at the input to the detector algorithm. (Default: 50.0)
        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
            at the input to the detector algorithm. (Default: 6000.0)
        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
            in the detector algorithm. (Default: 150.0)
        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
            in the detector algorithm. (Default: 2000.0)

    References:
        http://sox.sourceforge.net/sox.html
    """

    def __init__(self,
                 sample_rate: int,
                 trigger_level: float = 7.0,
                 trigger_time: float = 0.25,
                 search_time: float = 1.0,
                 allowed_gap: float = 0.25,
                 pre_trigger_time: float = 0.0,
                 boot_time: float = .35,
                 noise_up_time: float = .1,
                 noise_down_time: float = .01,
                 noise_reduction_amount: float = 1.35,
                 measure_freq: float = 20.0,
                 measure_duration: Optional[float] = None,
                 measure_smooth_time: float = .4,
                 hp_filter_freq: float = 50.,
                 lp_filter_freq: float = 6000.,
                 hp_lifter_freq: float = 150.,
                 lp_lifter_freq: float = 2000.) -> None:
        super().__init__()

        self.sample_rate = sample_rate
        self.trigger_level = trigger_level
        self.trigger_time = trigger_time
        self.search_time = search_time
        self.allowed_gap = allowed_gap
        self.pre_trigger_time = pre_trigger_time
        self.boot_time = boot_time
        self.noise_up_time = noise_up_time
        self.noise_down_time = noise_down_time
        self.noise_reduction_amount = noise_reduction_amount
        self.measure_freq = measure_freq
        self.measure_duration = measure_duration
        self.measure_smooth_time = measure_smooth_time
        self.hp_filter_freq = hp_filter_freq
        self.lp_filter_freq = lp_filter_freq
        self.hp_lifter_freq = hp_lifter_freq
        self.lp_lifter_freq = lp_lifter_freq

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(..., time)`
        """
        return F.vad(
            waveform=waveform,
            sample_rate=self.sample_rate,
            trigger_level=self.trigger_level,
            trigger_time=self.trigger_time,
            search_time=self.search_time,
            allowed_gap=self.allowed_gap,
            pre_trigger_time=self.pre_trigger_time,
            boot_time=self.boot_time,
            noise_up_time=self.noise_up_time,
            noise_down_time=self.noise_down_time,
            noise_reduction_amount=self.noise_reduction_amount,
            measure_freq=self.measure_freq,
            measure_duration=self.measure_duration,
            measure_smooth_time=self.measure_smooth_time,
            hp_filter_freq=self.hp_filter_freq,
            lp_filter_freq=self.lp_filter_freq,
            hp_lifter_freq=self.hp_lifter_freq,
            lp_lifter_freq=self.lp_lifter_freq,
        )
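
# Usage sketch (illustrative; 'test.wav' is a hypothetical file): Vad trims
# only leading silence, so flip, re-apply, and flip back to trim the tail.
#
#   >>> waveform, sample_rate = torchaudio.load('test.wav')
#   >>> vad = Vad(sample_rate=sample_rate)
#   >>> front_trimmed = vad(waveform)
#   >>> both_trimmed = torch.flip(vad(torch.flip(front_trimmed, [-1])), [-1])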