Unverified Commit c4565245 authored by Vincent QB, committed by GitHub

extend batch support (#391)

* extend batch support

closes #383

* function for batch test.

* set seed.

* adjust tolerance for griffinlim.
parent 45498f26
@@ -103,6 +103,25 @@ class TestFunctional(unittest.TestCase):
self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5))
def test_batch_griffinlim(self):
torch.random.manual_seed(42)
tensor = torch.rand((1, 201, 6))
n_fft = 400
ws = 400
hop = 200
window = torch.hann_window(ws)
power = 2
normalize = False
momentum = 0.99
n_iter = 32
length = 1000
self._test_batch(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5
)
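# Standalone sketch (not part of the test suite) of the property the call above
# exercises: with rand_init=0 the phase initialization is deterministic, so running
# griffinlim on a stacked batch should match the unbatched result repeated along the
# batch dimension, up to the same tolerance.
import torch
import torchaudio.functional as F
torch.random.manual_seed(42)
spec = torch.rand(1, 201, 6)
win = torch.hann_window(400)
single = F.griffinlim(spec, win, 400, 200, 400, 2, False, 32, 0.99, 1000, 0)
batched = F.griffinlim(spec.unsqueeze(0).repeat(3, 1, 1, 1),
                       win, 400, 200, 400, 2, False, 32, 0.99, 1000, 0)
assert torch.allclose(batched, single.unsqueeze(0).repeat(3, 1, 1), atol=5e-5)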
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
@@ -126,22 +145,17 @@ class TestFunctional(unittest.TestCase):
win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time)
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)
def test_batch_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
# Single then transform then batch
expected = F.detect_pitch_frequency(waveform, sample_rate)
expected = expected.unsqueeze(0).repeat(3, 1, 1)
# Batch then transform
waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
computed = F.detect_pitch_frequency(waveform, sample_rate)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_jit_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
@@ -157,7 +171,6 @@ class TestFunctional(unittest.TestCase):
for data_size in self.data_sizes:
for i in range(self.number_of_trials):
# Non-batch
sound = common_utils.random_float_tensor(i, data_size)
stft = torch.stft(sound, **kwargs)
@@ -165,14 +178,6 @@ class TestFunctional(unittest.TestCase):
self._compare_estimate(sound, estimate)
# Batch
stft = torch.stft(sound, **kwargs)
stft = stft.repeat(3, 1, 1, 1, 1)
sound = sound.repeat(3, 1, 1)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
self._compare_estimate(sound, estimate)
def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided
kwargs1 = {
@@ -389,6 +394,16 @@ class TestFunctional(unittest.TestCase):
data_size = (2, 7, 3, 2)
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
def test_batch_istft(self):
stft = torch.tensor([
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
])
self._test_batch(F.istft, stft, n_fft=4, length=4)
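# Essentially the check the helper performs above, written out directly (illustrative,
# not part of the tests): inverting a (freq, frame, 2) onesided spectrogram and a
# (3, freq, frame, 2) stack of it should agree item by item.
import torch
import torchaudio.functional as F
stft = torch.zeros(3, 5, 2)
stft[0, :, 0] = 4.                      # energy only in the DC bin, real part
single = F.istft(stft, n_fft=4, length=4)
batched = F.istft(stft.unsqueeze(0).repeat(3, 1, 1, 1), n_fft=4, length=4)
assert torch.allclose(batched, single.unsqueeze(0).repeat(3, 1))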
def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
# Using a decorator here causes parametrize to fail on Python 2
if not IMPORT_LIBROSA:
@@ -489,22 +504,63 @@ class TestFunctional(unittest.TestCase):
self.assertFalse(s)
# Convert to stereo and batch for testing purposes
freq = freq.repeat(3, 2, 1, 1)
waveform = waveform.repeat(3, 2, 1, 1)
freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
assert torch.allclose(freq, freq2, atol=1e-5)
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
def _test_batch(self, functional):
waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100
# Single then transform then batch
expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1)
# Batch then transform
waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
computed = functional(waveform)
def _test_batch_shape(self, functional, tensor, *args, **kwargs):
kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol
if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol
torch.random.manual_seed(42)
expected = functional(tensor.clone(), *args, **kwargs)
expected = expected.unsqueeze(0).unsqueeze(0)
# 1-Batch then transform
tensors = tensor.unsqueeze(0).unsqueeze(0)
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
self._compare_estimate(computed, expected, **kwargs_compare)
return tensors, expected
def _test_batch(self, functional, tensor, *args, **kwargs):
tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs)
kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol
if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol
# 3-Batch then transform
ind = [3] + [1] * (int(tensors.dim()) - 1)
tensors = tensor.repeat(*ind)
ind = [3] + [1] * (int(expected.dim()) - 1)
expected = expected.repeat(*ind)
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
self._compare_estimate(computed, expected, **kwargs_compare)
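# Shape bookkeeping behind the two helpers above, shown standalone: a (channel, time)
# input is first lifted to a (1, 1, channel, time) single-item batch, then repeated
# into a (3, 1, channel, time) batch; Tensor.repeat prepends the missing dimensions,
# which is why repeating the original tensor with the longer index list works.
import torch
tensor = torch.randn(2, 1000)                 # e.g. a stereo waveform
tensors = tensor.unsqueeze(0).unsqueeze(0)
print(tensors.shape)                          # torch.Size([1, 1, 2, 1000])
ind = [3] + [1] * (tensors.dim() - 1)
print(tensor.repeat(*ind).shape)              # torch.Size([3, 1, 2, 1000])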
def test_torchscript_create_fb_matrix(self):
@@ -381,6 +381,19 @@ class Tester(unittest.TestCase):
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def test_batch_MelScale(self):
specgram = torch.randn(2, 31, 2786)
# Single then transform then batch
expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 128, 2786)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786)
@@ -440,6 +453,30 @@ class Tester(unittest.TestCase):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_melspectrogram(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
# Single then transform then batch
expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_mfcc(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
# Single then transform then batch
expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.MFCC()(waveform.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected, atol=1e-5))
def test_scriptmodule_TimeStretch(self):
n_freq = 400
hop_length = 512
@@ -99,7 +99,7 @@ def istft(
Args:
stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
column is a window. it has a size of either (..., fft_size, n_frame, 2)
column is a window. It has a size of either (..., fft_size, n_frame, 2)
n_fft (int): Size of Fourier transform
hop_length (Optional[int]): The distance between neighboring sliding window frames.
(Default: ``win_length // 4``)
@@ -230,7 +230,7 @@ def spectrogram(
The spectrogram can be either magnitude-only or complex.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
pad (int): Two sided padding of signal
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT
@@ -242,8 +242,8 @@ def spectrogram(
normalized (bool): Whether to normalize by magnitude after stft
Returns:
torch.Tensor: Dimension (..., channel, freq, time), where channel
is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of
torch.Tensor: Dimension (..., freq, time), freq is
``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""
@@ -292,7 +292,7 @@ def griffinlim(
IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
Args:
specgram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (channel, freq, frames)
specgram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
where freq is ``n_fft // 2 + 1``.
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
@@ -310,11 +310,15 @@ def griffinlim(
rand_init (bool): Initializes phase randomly if True, to zero otherwise. (Default: ``True``)
Returns:
torch.Tensor: waveform of (channel, time), where time equals the ``length`` parameter if given.
torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
"""
assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
assert momentum > 0, 'momentum=%s < 0' % momentum
# pack batch
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
specgram = specgram.pow(1 / power)
# randomly initialize the phase
@@ -351,12 +355,17 @@ def griffinlim(
angles = angles.div_(complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles))
# Return the final phase estimates
return istft(specgram * angles,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
length=length)
waveform = istft(specgram * angles,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
length=length)
# unpack batch
waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
return waveform
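# The pack/unpack idiom used above, in isolation: collapse all leading dimensions into
# one flat batch axis, run the per-item (freq, frames) -> (time) computation, then
# restore the original leading dimensions on the result.
import torch
x = torch.randn(3, 2, 201, 6)                      # (..., freq, frames)
shape = x.size()
packed = x.reshape([-1] + list(shape[-2:]))        # (6, 201, 6)
out = packed.reshape(packed.size(0), -1)           # stand-in for the per-item kernel -> (6, time)
out = out.reshape(shape[:-2] + out.shape[-1:])     # (3, 2, time)
print(out.shape)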
def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
@@ -699,7 +708,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
https://en.wikipedia.org/wiki/Digital_biquad_filter
Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
b0 (float): numerator coefficient of current input, x[n]
b1 (float): numerator coefficient of input one time step ago x[n-1]
b2 (float): numerator coefficient of input two time steps ago x[n-2]
@@ -708,7 +717,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
a2 (float): denominator coefficient of current output y[n-2]
Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`
output_waveform (torch.Tensor): Dimension of `(..., time)`
"""
device = waveform.device
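# Reference recursion for the difference equation behind biquad (illustrative only,
# not the library implementation), with zero initial conditions and a (..., time) input:
#   y[n] = (b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]) / a0
import torch
def biquad_reference(x, b0, b1, b2, a0, a1, a2):
    y = torch.zeros_like(x)
    x1 = x2 = y1 = y2 = torch.zeros_like(x[..., 0])
    for n in range(x.size(-1)):
        xn = x[..., n]
        yn = (b0 * xn + b1 * x1 + b2 * x2 - a1 * y1 - a2 * y2) / a0
        y[..., n] = yn
        x2, x1 = x1, xn
        y2, y1 = y1, yn
    return y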
@@ -732,13 +741,13 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor
Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`
output_waveform (torch.Tensor): Dimension of `(..., time)`
"""
GAIN = 1.
@@ -761,13 +770,13 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor
Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`
output_waveform (torch.Tensor): Dimension of `(..., time)`
"""
GAIN = 1.
@@ -790,14 +799,14 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
center_freq (float): filter's central frequency
gain (float): desired gain at the boost (or attenuation) in dB
q_factor (float): https://en.wikipedia.org/wiki/Q_factor
Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`
output_waveform (torch.Tensor): Dimension of `(..., time)`
"""
w0 = 2 * math.pi * center_freq / sample_rate
A = math.exp(gain / 40.0 * math.log(10))
@@ -886,7 +895,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
# unpack batch
specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
return specgram.reshape(shape[:-2] + specgram.shape[-2:])
return specgram
def compute_deltas(specgram, win_length=5, mode="replicate"):
@@ -946,7 +955,7 @@ def gain(waveform, gain_db=1.0):
r"""Apply amplification or attenuation to the whole waveform.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
gain_db (float): Gain adjustment in decibels (dB) (Default: `1.0`).
Returns:
@@ -999,7 +1008,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
The relationship of probabilities of results follows a bell-shaped,
or Gaussian curve, typical of dither generated by analog sources.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
probability_density_function (string): The density function of a
continuous random variable (Default: `TPDF`)
Options: Triangular Probability Density Function - `TPDF`
@@ -1008,6 +1017,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
Returns:
torch.Tensor: waveform dithered with TPDF
"""
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
@@ -1047,6 +1058,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
quantised_signal_scaled = torch.round(signal_scaled_dis)
quantised_signal = quantised_signal_scaled / down_scaling
# unpack batch
return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
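# Sketch of the TPDF idea described above (illustrative, not the library code):
# triangular noise, the sum of two independent uniform variables, is added just
# before the signal is rounded onto the target quantisation grid.
import torch
def tpdf_dither_sketch(waveform, bits=16):
    scale = 2 ** (bits - 1) - 1
    noise = torch.rand_like(waveform) + torch.rand_like(waveform) - 1.0  # triangular on (-1, 1) LSB
    return torch.round(waveform * scale + noise) / scale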
@@ -1056,7 +1069,7 @@ def dither(waveform, density_function="TPDF", noise_shaping=False):
particular bit-depth by eliminating nonlinear truncation distortion
(i.e. adding minimally perceived noise to mask distortion caused by quantization).
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
density_function (string): The density function of a
continuous random variable (Default: `TPDF`)
Options: Triangular Probability Density Function - `TPDF`
@@ -62,11 +62,11 @@ class Spectrogram(torch.nn.Module):
def forward(self, waveform):
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
Returns:
torch.Tensor: Dimension (channel, freq, time), where channel
is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
torch.Tensor: Dimension (..., freq, time), where freq is
``n_fft // 2 + 1`` where ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""
return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
@@ -207,11 +207,16 @@ class MelScale(torch.nn.Module):
def forward(self, specgram):
r"""
Args:
specgram (torch.Tensor): A spectrogram STFT of dimension (channel, freq, time)
specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time)
Returns:
torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time)
torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
"""
# pack batch
shape = specgram.size()
specgram = specgram.reshape(-1, shape[-2], shape[-1])
if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
# Attributes cannot be reassigned outside __init__ so workaround
@@ -221,6 +226,10 @@ class MelScale(torch.nn.Module):
# (channel, frequency, time).transpose(...) dot (frequency, n_mels)
# -> (channel, time, n_mels).transpose(...)
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
# unpack batch
mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
return mel_specgram
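# Batched usage of MelScale as exercised by the new tests above (the output size of
# 128 assumes this version's default n_mels):
import torch
import torchaudio.transforms as transforms
specgram = torch.randn(3, 2, 201, 100)        # (batch, channel, freq, time)
mel = transforms.MelScale()(specgram)
print(mel.shape)                              # expected torch.Size([3, 2, 128, 100])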
@@ -273,10 +282,10 @@ class MelSpectrogram(torch.nn.Module):
def forward(self, waveform):
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
Returns:
torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time)
torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
"""
specgram = self.spectrogram(waveform)
mel_specgram = self.mel_scale(specgram)
@@ -332,11 +341,16 @@ class MFCC(torch.nn.Module):
def forward(self, waveform):
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
Returns:
torch.Tensor: specgram_mel_db of size (channel, ``n_mfcc``, time)
torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time)
"""
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
mel_specgram = self.MelSpectrogram(waveform)
if self.log_mels:
log_offset = 1e-6
@@ -346,6 +360,10 @@ class MFCC(torch.nn.Module):
# (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
# -> (channel, time, n_mfcc).transpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
# unpack batch
mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])
return mfcc
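# Batched MFCC usage matching the docstring above (an n_mfcc default of 40 is assumed):
import torch
import torchaudio.transforms as transforms
waveform = torch.randn(3, 2, 16000)           # (batch, channel, time)
mfcc = transforms.MFCC()(waveform)
print(mfcc.shape)                             # (3, 2, 40, time)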
@@ -421,10 +439,10 @@ class Resample(torch.nn.Module):
def forward(self, waveform):
r"""
Args:
waveform (torch.Tensor): The input signal of dimension (channel, time)
waveform (torch.Tensor): The input signal of dimension (..., time)
Returns:
torch.Tensor: Output signal of dimension (channel, time)
torch.Tensor: Output signal of dimension (..., time)
"""
if self.resampling_method == 'sinc_interpolation':
return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
@@ -471,10 +489,10 @@ class ComputeDeltas(torch.nn.Module):
def forward(self, specgram):
r"""
Args:
specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)
Returns:
deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)
"""
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
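# compute_deltas preserves the input shape, including any leading batch dimensions
# (standalone sketch of the contract documented above):
import torch
import torchaudio.functional as F
specgram = torch.randn(3, 2, 40, 100)         # (batch, channel, freq, time)
deltas = F.compute_deltas(specgram, win_length=5)
print(deltas.shape)                           # torch.Size([3, 2, 40, 100])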