Unverified Commit c4565245 authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

extend batch support (#391)

* extend batch support

closes #383

* function for batch test.

* set seed.

* adjust tolerance for griffinlim.
parent 45498f26
...@@ -103,6 +103,25 @@ class TestFunctional(unittest.TestCase): ...@@ -103,6 +103,25 @@ class TestFunctional(unittest.TestCase):
self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5)) self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5))
def test_batch_griffinlim(self):
    """Run the shared batch-consistency helper on ``F.griffinlim``."""
    # Seed first so the random magnitude spectrogram below is reproducible.
    torch.random.manual_seed(42)
    specgram = torch.rand((1, 201, 6))

    # STFT geometry: the 201 frequency rows equal n_fft // 2 + 1 for n_fft = 400.
    fft_size = 400
    win_size = 400
    hop_size = 200
    hann = torch.hann_window(win_size)

    # Remaining positional arguments, listed in the order they are passed on.
    exponent = 2
    normalized = False
    n_iterations = 32
    momentum_value = 0.99
    out_length = 1000
    trailing_flag = 0  # last positional argument; presumably rand_init -- TODO confirm

    self._test_batch(
        F.griffinlim,
        specgram,
        hann,
        fft_size,
        hop_size,
        win_size,
        exponent,
        normalized,
        n_iterations,
        momentum_value,
        out_length,
        trailing_flag,
        atol=5e-5
    )
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length) computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
...@@ -126,22 +145,17 @@ class TestFunctional(unittest.TestCase): ...@@ -126,22 +145,17 @@ class TestFunctional(unittest.TestCase):
win_length = 2 * 7 + 1 win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time) specgram = torch.randn(channel, n_mfcc, time)
computed = F.compute_deltas(specgram, win_length=win_length) computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length) _test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)
def test_batch_pitch(self): def test_batch_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath) waveform, sample_rate = torchaudio.load(self.test_filepath)
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
# Single then transform then batch def test_jit_pitch(self):
expected = F.detect_pitch_frequency(waveform, sample_rate) waveform, sample_rate = torchaudio.load(self.test_filepath)
expected = expected.unsqueeze(0).repeat(3, 1, 1)
# Batch then transform
waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
computed = F.detect_pitch_frequency(waveform, sample_rate)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate) _test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
...@@ -157,7 +171,6 @@ class TestFunctional(unittest.TestCase): ...@@ -157,7 +171,6 @@ class TestFunctional(unittest.TestCase):
for data_size in self.data_sizes: for data_size in self.data_sizes:
for i in range(self.number_of_trials): for i in range(self.number_of_trials):
# Non-batch
sound = common_utils.random_float_tensor(i, data_size) sound = common_utils.random_float_tensor(i, data_size)
stft = torch.stft(sound, **kwargs) stft = torch.stft(sound, **kwargs)
...@@ -165,14 +178,6 @@ class TestFunctional(unittest.TestCase): ...@@ -165,14 +178,6 @@ class TestFunctional(unittest.TestCase):
self._compare_estimate(sound, estimate) self._compare_estimate(sound, estimate)
# Batch
stft = torch.stft(sound, **kwargs)
stft = stft.repeat(3, 1, 1, 1, 1)
sound = sound.repeat(3, 1, 1)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
self._compare_estimate(sound, estimate)
def test_istft_is_inverse_of_stft1(self): def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided # hann_window, centered, normalized, onesided
kwargs1 = { kwargs1 = {
...@@ -389,6 +394,16 @@ class TestFunctional(unittest.TestCase): ...@@ -389,6 +394,16 @@ class TestFunctional(unittest.TestCase):
data_size = (2, 7, 3, 2) data_size = (2, 7, 3, 2)
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8) self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
def test_batch_istft(self):
    """Run the shared batch-consistency helper on ``F.istft``."""
    # Shape (3, 5, 2): three rows, five columns, and a trailing pair per entry.
    # Only the first component of the first row is non-zero.
    n_rows, n_cols = 3, 5
    spectrum = torch.zeros(n_rows, n_cols, 2)
    spectrum[0, :, 0] = 4.

    self._test_batch(F.istft, spectrum, n_fft=4, length=4)
def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0): def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
# Using a decorator here causes parametrize to fail on Python 2 # Using a decorator here causes parametrize to fail on Python 2
if not IMPORT_LIBROSA: if not IMPORT_LIBROSA:
...@@ -489,22 +504,63 @@ class TestFunctional(unittest.TestCase): ...@@ -489,22 +504,63 @@ class TestFunctional(unittest.TestCase):
self.assertFalse(s) self.assertFalse(s)
# Convert to stereo and batch for testing purposes # Convert to stereo and batch for testing purposes
freq = freq.repeat(3, 2, 1, 1) self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
waveform = waveform.repeat(3, 2, 1, 1)
freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate) def _test_batch_shape(self, functional, tensor, *args, **kwargs):
assert torch.allclose(freq, freq2, atol=1e-5) kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol
def _test_batch(self, functional): if 'rtol' in kwargs:
waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol
# Single then transform then batch # Single then transform then batch
expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1)
# Batch then transform torch.random.manual_seed(42)
waveform = waveform.unsqueeze(0).repeat(3, 1, 1) expected = functional(tensor.clone(), *args, **kwargs)
computed = functional(waveform) expected = expected.unsqueeze(0).unsqueeze(0)
# 1-Batch then transform
tensors = tensor.unsqueeze(0).unsqueeze(0)
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
self._compare_estimate(computed, expected, **kwargs_compare)
return tensors, expected
def _test_batch(self, functional, tensor, *args, **kwargs):
    """Check that ``functional`` agrees between a single item and a batch of three.

    The 1-batch check is delegated to ``_test_batch_shape``; this method then
    tiles both the input and the reference output three times along the
    leading dimension, runs ``functional`` on the tiled input and compares
    the result with the tiled reference.

    Args:
        functional: Callable under test, invoked as
            ``functional(tensor, *args, **kwargs)``.
        tensor (torch.Tensor): Un-batched input for ``functional``.
        *args: Extra positional arguments forwarded to ``functional``.
        **kwargs: Extra keyword arguments forwarded to ``functional``.
            ``atol`` and ``rtol`` are removed first and passed to
            ``_compare_estimate`` instead.
    """
    # The 1-batch check hands back the input and the reference output, each
    # with two leading singleton dimensions added.
    tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs)

    # Tolerances belong to the comparison only; strip them so that they are
    # not forwarded to ``functional``.
    kwargs_compare = {}
    for name in ('atol', 'rtol'):
        if name in kwargs:
            kwargs_compare[name] = kwargs.pop(name)

    # 3-Batch then transform.  Repeat the already-unsqueezed ``tensors`` so
    # the repeat pattern and the tensor it is applied to have matching rank.
    tensors = tensors.repeat(*([3] + [1] * (int(tensors.dim()) - 1)))
    expected = expected.repeat(*([3] + [1] * (int(expected.dim()) - 1)))

    # Same seed as the reference run so any randomness inside ``functional``
    # starts from the same generator state.
    torch.random.manual_seed(42)
    computed = functional(tensors.clone(), *args, **kwargs)

    # Without this comparison the 3-batch path would assert nothing.
    self._compare_estimate(computed, expected, **kwargs_compare)
def test_torchscript_create_fb_matrix(self): def test_torchscript_create_fb_matrix(self):
......
...@@ -381,6 +381,19 @@ class Tester(unittest.TestCase): ...@@ -381,6 +381,19 @@ class Tester(unittest.TestCase):
computed = transform(specgram) computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def test_batch_MelScale(self):
    """Compare ``MelScale`` applied before tiling with ``MelScale`` applied after tiling."""
    single = torch.randn(2, 31, 2786)
    batch_dims = (3, 1, 1, 1)

    # Transform once, then tile the result along a new leading batch axis.
    reference = transforms.MelScale()(single).repeat(*batch_dims)

    # Tile the input first to (3, 2, 31, 2786), then transform with a fresh module.
    batched_input = single.repeat(*batch_dims)
    result = transforms.MelScale()(batched_input)

    # Both results are expected to be (3, 2, n_mels, 2786).
    shapes = (result.shape, reference.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(result, reference))
def test_batch_compute_deltas(self): def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786) specgram = torch.randn(2, 31, 2786)
...@@ -440,6 +453,30 @@ class Tester(unittest.TestCase): ...@@ -440,6 +453,30 @@ class Tester(unittest.TestCase):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected)) self.assertTrue(torch.allclose(computed, expected))
def test_batch_melspectrogram(self):
    """Compare ``MelSpectrogram`` applied before tiling with it applied after tiling."""
    # The sample rate returned by the loader is not needed here.
    audio, _ = torchaudio.load(self.test_filepath)

    # Transform once, then tile the result along a new leading batch axis.
    reference = transforms.MelSpectrogram()(audio).repeat(3, 1, 1, 1)

    # Tile the waveform first, then transform with a fresh module.
    batched_audio = audio.repeat(3, 1, 1)
    result = transforms.MelSpectrogram()(batched_audio)

    shapes = (result.shape, reference.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(result, reference))
def test_batch_mfcc(self):
    """Compare ``MFCC`` applied before tiling with ``MFCC`` applied after tiling."""
    # The sample rate returned by the loader is not needed here.
    audio, _ = torchaudio.load(self.test_filepath)

    # Transform once, then tile the result along a new leading batch axis.
    reference = transforms.MFCC()(audio).repeat(3, 1, 1, 1)

    # Tile the waveform first, then transform with a fresh module.
    batched_audio = audio.repeat(3, 1, 1)
    result = transforms.MFCC()(batched_audio)

    shapes = (result.shape, reference.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    # A looser absolute tolerance than the allclose default is used here.
    self.assertTrue(torch.allclose(result, reference, atol=1e-5))
def test_scriptmodule_TimeStretch(self): def test_scriptmodule_TimeStretch(self):
n_freq = 400 n_freq = 400
hop_length = 512 hop_length = 512
......
...@@ -99,7 +99,7 @@ def istft( ...@@ -99,7 +99,7 @@ def istft(
Args: Args:
stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
column is a window. it has a size of either (..., fft_size, n_frame, 2) column is a window. It has a size of either (..., fft_size, n_frame, 2)
n_fft (int): Size of Fourier transform n_fft (int): Size of Fourier transform
hop_length (Optional[int]): The distance between neighboring sliding window frames. hop_length (Optional[int]): The distance between neighboring sliding window frames.
(Default: ``win_length // 4``) (Default: ``win_length // 4``)
...@@ -230,7 +230,7 @@ def spectrogram( ...@@ -230,7 +230,7 @@ def spectrogram(
The spectrogram can be either magnitude-only or complex. The spectrogram can be either magnitude-only or complex.
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time) waveform (torch.Tensor): Tensor of audio of dimension (..., time)
pad (int): Two sided padding of signal pad (int): Two sided padding of signal
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT n_fft (int): Size of FFT
...@@ -242,8 +242,8 @@ def spectrogram( ...@@ -242,8 +242,8 @@ def spectrogram(
normalized (bool): Whether to normalize by magnitude after stft normalized (bool): Whether to normalize by magnitude after stft
Returns: Returns:
torch.Tensor: Dimension (..., channel, freq, time), where channel torch.Tensor: Dimension (..., freq, time), freq is
is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of ``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame). Fourier bins, and time is the number of window hops (n_frame).
""" """
...@@ -292,7 +292,7 @@ def griffinlim( ...@@ -292,7 +292,7 @@ def griffinlim(
IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984. IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
Args: Args:
specgram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (channel, freq, frames) specgram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
where freq is ``n_fft // 2 + 1``. where freq is ``n_fft // 2 + 1``.
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
...@@ -310,11 +310,15 @@ def griffinlim( ...@@ -310,11 +310,15 @@ def griffinlim(
rand_init (bool): Initializes phase randomly if True, to zero otherwise. (Default: ``True``) rand_init (bool): Initializes phase randomly if True, to zero otherwise. (Default: ``True``)
Returns: Returns:
torch.Tensor: waveform of (channel, time), where time equals the ``length`` parameter if given. torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
""" """
assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
assert momentum > 0, 'momentum=%s < 0' % momentum assert momentum > 0, 'momentum=%s < 0' % momentum
# pack batch
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
specgram = specgram.pow(1 / power) specgram = specgram.pow(1 / power)
# randomly initialize the phase # randomly initialize the phase
...@@ -351,12 +355,17 @@ def griffinlim( ...@@ -351,12 +355,17 @@ def griffinlim(
angles = angles.div_(complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles)) angles = angles.div_(complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles))
# Return the final phase estimates # Return the final phase estimates
return istft(specgram * angles, waveform = istft(specgram * angles,
n_fft=n_fft, n_fft=n_fft,
hop_length=hop_length, hop_length=hop_length,
win_length=win_length, win_length=win_length,
window=window, window=window,
length=length) length=length)
# unpack batch
waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
return waveform
def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
...@@ -699,7 +708,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2): ...@@ -699,7 +708,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
https://en.wikipedia.org/wiki/Digital_biquad_filter https://en.wikipedia.org/wiki/Digital_biquad_filter
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
b0 (float): numerator coefficient of current input, x[n] b0 (float): numerator coefficient of current input, x[n]
b1 (float): numerator coefficient of input one time step ago x[n-1] b1 (float): numerator coefficient of input one time step ago x[n-1]
b2 (float): numerator coefficient of input two time steps ago x[n-2] b2 (float): numerator coefficient of input two time steps ago x[n-2]
...@@ -708,7 +717,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2): ...@@ -708,7 +717,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
a2 (float): denominator coefficient of current output y[n-2] a2 (float): denominator coefficient of current output y[n-2]
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)` output_waveform (torch.Tensor): Dimension of `(..., time)`
""" """
device = waveform.device device = waveform.device
...@@ -732,13 +741,13 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): ...@@ -732,13 +741,13 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation. r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor Q (float): https://en.wikipedia.org/wiki/Q_factor
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)` output_waveform (torch.Tensor): Dimension of `(..., time)`
""" """
GAIN = 1. GAIN = 1.
...@@ -761,13 +770,13 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): ...@@ -761,13 +770,13 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation. r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor Q (float): https://en.wikipedia.org/wiki/Q_factor
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)` output_waveform (torch.Tensor): Dimension of `(..., time)`
""" """
GAIN = 1. GAIN = 1.
...@@ -790,14 +799,14 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): ...@@ -790,14 +799,14 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation. r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation.
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
center_freq (float): filter's central frequency center_freq (float): filter's central frequency
gain (float): desired gain at the boost (or attenuation) in dB gain (float): desired gain at the boost (or attenuation) in dB
Q (float): https://en.wikipedia.org/wiki/Q_factor Q (float): https://en.wikipedia.org/wiki/Q_factor
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)` output_waveform (torch.Tensor): Dimension of `(..., time)`
""" """
w0 = 2 * math.pi * center_freq / sample_rate w0 = 2 * math.pi * center_freq / sample_rate
A = math.exp(gain / 40.0 * math.log(10)) A = math.exp(gain / 40.0 * math.log(10))
...@@ -886,7 +895,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): ...@@ -886,7 +895,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
# unpack batch # unpack batch
specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:]) specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
return specgram.reshape(shape[:-2] + specgram.shape[-2:]) return specgram
def compute_deltas(specgram, win_length=5, mode="replicate"): def compute_deltas(specgram, win_length=5, mode="replicate"):
...@@ -946,7 +955,7 @@ def gain(waveform, gain_db=1.0): ...@@ -946,7 +955,7 @@ def gain(waveform, gain_db=1.0):
r"""Apply amplification or attenuation to the whole waveform. r"""Apply amplification or attenuation to the whole waveform.
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time). waveform (torch.Tensor): Tensor of audio of dimension (..., time).
gain_db (float): Gain adjustment in decibels (dB) (Default: `1.0`). gain_db (float): Gain adjustment in decibels (dB) (Default: `1.0`).
Returns: Returns:
...@@ -999,7 +1008,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): ...@@ -999,7 +1008,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
The relationship of probabilities of results follows a bell-shaped, The relationship of probabilities of results follows a bell-shaped,
or Gaussian curve, typical of dither generated by analog sources. or Gaussian curve, typical of dither generated by analog sources.
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time) waveform (torch.Tensor): Tensor of audio of dimension (..., time)
probability_density_function (string): The density function of a probability_density_function (string): The density function of a
continuous random variable (Default: `TPDF`) continuous random variable (Default: `TPDF`)
Options: Triangular Probability Density Function - `TPDF` Options: Triangular Probability Density Function - `TPDF`
...@@ -1008,6 +1017,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): ...@@ -1008,6 +1017,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
Returns: Returns:
torch.Tensor: waveform dithered with TPDF torch.Tensor: waveform dithered with TPDF
""" """
# pack batch
shape = waveform.size() shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1]) waveform = waveform.reshape(-1, shape[-1])
...@@ -1047,6 +1058,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): ...@@ -1047,6 +1058,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
quantised_signal_scaled = torch.round(signal_scaled_dis) quantised_signal_scaled = torch.round(signal_scaled_dis)
quantised_signal = quantised_signal_scaled / down_scaling quantised_signal = quantised_signal_scaled / down_scaling
# unpack batch
return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:]) return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
...@@ -1056,7 +1069,7 @@ def dither(waveform, density_function="TPDF", noise_shaping=False): ...@@ -1056,7 +1069,7 @@ def dither(waveform, density_function="TPDF", noise_shaping=False):
particular bit-depth by eliminating nonlinear truncation distortion particular bit-depth by eliminating nonlinear truncation distortion
(i.e. adding minimally perceived noise to mask distortion caused by quantization). (i.e. adding minimally perceived noise to mask distortion caused by quantization).
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time) waveform (torch.Tensor): Tensor of audio of dimension (..., time)
density_function (string): The density function of a density_function (string): The density function of a
continuous random variable (Default: `TPDF`) continuous random variable (Default: `TPDF`)
Options: Triangular Probability Density Function - `TPDF` Options: Triangular Probability Density Function - `TPDF`
......
...@@ -62,11 +62,11 @@ class Spectrogram(torch.nn.Module): ...@@ -62,11 +62,11 @@ class Spectrogram(torch.nn.Module):
def forward(self, waveform): def forward(self, waveform):
r""" r"""
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time) waveform (torch.Tensor): Tensor of audio of dimension (..., time)
Returns: Returns:
torch.Tensor: Dimension (channel, freq, time), where channel torch.Tensor: Dimension (..., freq, time), where freq is
is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of ``n_fft // 2 + 1`` where ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame). Fourier bins, and time is the number of window hops (n_frame).
""" """
return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length, return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
...@@ -207,11 +207,16 @@ class MelScale(torch.nn.Module): ...@@ -207,11 +207,16 @@ class MelScale(torch.nn.Module):
def forward(self, specgram): def forward(self, specgram):
r""" r"""
Args: Args:
specgram (torch.Tensor): A spectrogram STFT of dimension (channel, freq, time) specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time)
Returns: Returns:
torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time) torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
""" """
# pack batch
shape = specgram.size()
specgram = specgram.reshape(-1, shape[-2], shape[-1])
if self.fb.numel() == 0: if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate) tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
# Attributes cannot be reassigned outside __init__ so workaround # Attributes cannot be reassigned outside __init__ so workaround
...@@ -221,6 +226,10 @@ class MelScale(torch.nn.Module): ...@@ -221,6 +226,10 @@ class MelScale(torch.nn.Module):
# (channel, frequency, time).transpose(...) dot (frequency, n_mels) # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
# -> (channel, time, n_mels).transpose(...) # -> (channel, time, n_mels).transpose(...)
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2) mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
# unpack batch
mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
return mel_specgram return mel_specgram
...@@ -273,10 +282,10 @@ class MelSpectrogram(torch.nn.Module): ...@@ -273,10 +282,10 @@ class MelSpectrogram(torch.nn.Module):
def forward(self, waveform): def forward(self, waveform):
r""" r"""
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time) waveform (torch.Tensor): Tensor of audio of dimension (..., time)
Returns: Returns:
torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time) torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
""" """
specgram = self.spectrogram(waveform) specgram = self.spectrogram(waveform)
mel_specgram = self.mel_scale(specgram) mel_specgram = self.mel_scale(specgram)
...@@ -332,11 +341,16 @@ class MFCC(torch.nn.Module): ...@@ -332,11 +341,16 @@ class MFCC(torch.nn.Module):
def forward(self, waveform): def forward(self, waveform):
r""" r"""
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time) waveform (torch.Tensor): Tensor of audio of dimension (..., time)
Returns: Returns:
torch.Tensor: specgram_mel_db of size (channel, ``n_mfcc``, time) torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time)
""" """
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
mel_specgram = self.MelSpectrogram(waveform) mel_specgram = self.MelSpectrogram(waveform)
if self.log_mels: if self.log_mels:
log_offset = 1e-6 log_offset = 1e-6
...@@ -346,6 +360,10 @@ class MFCC(torch.nn.Module): ...@@ -346,6 +360,10 @@ class MFCC(torch.nn.Module):
# (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc) # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
# -> (channel, time, n_mfcc).transpose(...) # -> (channel, time, n_mfcc).transpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
# unpack batch
mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])
return mfcc return mfcc
...@@ -421,10 +439,10 @@ class Resample(torch.nn.Module): ...@@ -421,10 +439,10 @@ class Resample(torch.nn.Module):
def forward(self, waveform): def forward(self, waveform):
r""" r"""
Args: Args:
waveform (torch.Tensor): The input signal of dimension (channel, time) waveform (torch.Tensor): The input signal of dimension (..., time)
Returns: Returns:
torch.Tensor: Output signal of dimension (channel, time) torch.Tensor: Output signal of dimension (..., time)
""" """
if self.resampling_method == 'sinc_interpolation': if self.resampling_method == 'sinc_interpolation':
return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq) return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
...@@ -471,10 +489,10 @@ class ComputeDeltas(torch.nn.Module): ...@@ -471,10 +489,10 @@ class ComputeDeltas(torch.nn.Module):
def forward(self, specgram): def forward(self, specgram):
r""" r"""
Args: Args:
specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time) specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)
Returns: Returns:
deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time) deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)
""" """
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode) return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment