Unverified Commit f5f79d1d authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Batching for transforms (#337)

* batching for transforms.

* test for batching.

* update readme.
parent 1500d4ef
...@@ -129,7 +129,7 @@ Transforms expect and return the following dimensions. ...@@ -129,7 +129,7 @@ Transforms expect and return the following dimensions.
* `MuLawDecoding`: (channel, time) -> (channel, time) * `MuLawDecoding`: (channel, time) -> (channel, time)
* `Resample`: (channel, time) -> (channel, time) * `Resample`: (channel, time) -> (channel, time)
Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. optional batching and channel dimensions.
Contributing Guidelines Contributing Guidelines
----------------------- -----------------------
......
...@@ -21,6 +21,10 @@ class TestFunctional(unittest.TestCase): ...@@ -21,6 +21,10 @@ class TestFunctional(unittest.TestCase):
number_of_trials = 100 number_of_trials = 100
specgram = torch.tensor([1., 2., 3., 4.]) specgram = torch.tensor([1., 2., 3., 4.])
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3')
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length) computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
...@@ -46,6 +50,20 @@ class TestFunctional(unittest.TestCase): ...@@ -46,6 +50,20 @@ class TestFunctional(unittest.TestCase):
computed = F.compute_deltas(specgram, win_length=win_length) computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def test_batch_pitch(self):
    """Pitch detection on a batch must equal the per-item result tiled into a batch."""
    single_waveform, sample_rate = torchaudio.load(self.test_filepath)

    # Reference path: run on the single item, then tile along a new leading axis.
    single_freq = F.detect_pitch_frequency(single_waveform, sample_rate)
    expected = single_freq.unsqueeze(0).repeat(3, 1, 1)

    # Path under test: tile the input first, then run on the whole batch.
    batched_waveform = single_waveform.unsqueeze(0).repeat(3, 1, 1)
    computed = F.detect_pitch_frequency(batched_waveform, sample_rate)

    shapes = (computed.shape, expected.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(computed, expected))
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original # trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)] sound = sound[..., :estimate.size(-1)]
...@@ -58,6 +76,8 @@ class TestFunctional(unittest.TestCase): ...@@ -58,6 +76,8 @@ class TestFunctional(unittest.TestCase):
# operation to check whether we can reconstruct signal # operation to check whether we can reconstruct signal
for data_size in self.data_sizes: for data_size in self.data_sizes:
for i in range(self.number_of_trials): for i in range(self.number_of_trials):
# Non-batch
sound = common_utils.random_float_tensor(i, data_size) sound = common_utils.random_float_tensor(i, data_size)
stft = torch.stft(sound, **kwargs) stft = torch.stft(sound, **kwargs)
...@@ -65,6 +85,14 @@ class TestFunctional(unittest.TestCase): ...@@ -65,6 +85,14 @@ class TestFunctional(unittest.TestCase):
self._compare_estimate(sound, estimate) self._compare_estimate(sound, estimate)
# Batch
stft = torch.stft(sound, **kwargs)
stft = stft.repeat(3, 1, 1, 1, 1)
sound = sound.repeat(3, 1, 1)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
self._compare_estimate(sound, estimate)
def test_istft_is_inverse_of_stft1(self): def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided # hann_window, centered, normalized, onesided
kwargs1 = { kwargs1 = {
...@@ -326,15 +354,30 @@ class TestFunctional(unittest.TestCase): ...@@ -326,15 +354,30 @@ class TestFunctional(unittest.TestCase):
for filename, freq_ref in tests: for filename, freq_ref in tests:
waveform, sample_rate = torchaudio.load(filename) waveform, sample_rate = torchaudio.load(filename)
# Convert to stereo for testing purposes
waveform = waveform.repeat(2, 1, 1)
freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate) freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
threshold = 1 threshold = 1
s = ((freq - freq_ref).abs() > threshold).sum() s = ((freq - freq_ref).abs() > threshold).sum()
self.assertFalse(s) self.assertFalse(s)
# Convert to stereo and batch for testing purposes
freq = freq.repeat(3, 2, 1, 1)
waveform = waveform.repeat(3, 2, 1, 1)
freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
assert torch.allclose(freq, freq2, atol=1e-5)
def _test_batch(self, functional):
    """Check that ``functional`` gives the same result on a batch as on one item.

    The test asset is transformed once and the result tiled three times along
    a new leading axis; that reference is compared with the output obtained by
    tiling the input first and transforming the whole batch.

    Args:
        functional: Callable taking a waveform tensor and returning a tensor.
            The reference is tiled with ``repeat(3, 1, 1, 1)`` after adding a
            batch axis, which presumes a 3-D per-item output -- TODO confirm
            for every caller.
    """
    # The sample rate is not needed by this helper.
    waveform, _ = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

    # Single then transform then batch
    expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1)

    # Batch then transform
    waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
    computed = functional(waveform)

    # Compare both paths; without these checks the helper could never fail.
    self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
    self.assertTrue(torch.allclose(computed, expected))
def _num_stft_bins(signal_len, fft_len, hop_length, pad): def _num_stft_bins(signal_len, fft_len, hop_length, pad):
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
......
...@@ -313,6 +313,45 @@ class Tester(unittest.TestCase): ...@@ -313,6 +313,45 @@ class Tester(unittest.TestCase):
computed = transform(specgram) computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def test_batch_compute_deltas(self):
    """ComputeDeltas on a batch must equal the per-item result tiled into a batch."""
    single_specgram = torch.randn(2, 31, 2786)

    # Reference path: transform the single item, then tile three times.
    single_deltas = transforms.ComputeDeltas()(single_specgram)
    expected = single_deltas.repeat(3, 1, 1, 1)

    # Path under test: tile the input first, then transform the batch.
    batched_specgram = single_specgram.repeat(3, 1, 1, 1)
    computed = transforms.ComputeDeltas()(batched_specgram)

    # The batched input is (3, 2, 31, 2786); both results must share one shape.
    shapes = (computed.shape, expected.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(computed, expected))
def test_batch_mulaw(self):
    """Mu-law encoding and decoding on a batch must match the per-item results."""
    single_waveform, sample_rate = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

    # --- Encoding ---
    # Reference path: encode the single item, then tile along a new batch axis.
    single_encoded = transforms.MuLawEncoding()(single_waveform)
    expected = single_encoded.unsqueeze(0).repeat(3, 1, 1)

    # Path under test: tile the input first, then encode the whole batch.
    batched_waveform = single_waveform.unsqueeze(0).repeat(3, 1, 1)
    computed = transforms.MuLawEncoding()(batched_waveform)

    # The batched input is (3, 2, 278756) given the (2, 278756) asset noted above.
    shapes = (computed.shape, expected.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(computed, expected))

    # --- Decoding ---
    # Reference path: decode the single encoded item, then tile it.
    single_decoded = transforms.MuLawDecoding()(single_encoded)
    expected = single_decoded.unsqueeze(0).repeat(3, 1, 1)

    # Path under test: decode the batched encoding produced above.
    computed = transforms.MuLawDecoding()(computed)

    shapes = (computed.shape, expected.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(computed, expected))
def test_batch_spectrogram(self): def test_batch_spectrogram(self):
waveform, sample_rate = torchaudio.load(self.test_filepath) waveform, sample_rate = torchaudio.load(self.test_filepath)
......
...@@ -114,16 +114,20 @@ def istft( ...@@ -114,16 +114,20 @@ def istft(
original signal length). (Default: whole signal) original signal length). (Default: whole signal)
Returns: Returns:
torch.Tensor: Least squares estimation of the original signal of size torch.Tensor: Least squares estimation of the original signal of size (..., signal_length)
(channel, signal_length) or (signal_length)
""" """
stft_matrix_dim = stft_matrix.dim() stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim) assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim)
assert stft_matrix.nelement() > 0
if stft_matrix_dim == 3: if stft_matrix_dim == 3:
# add a channel dimension # add a channel dimension
stft_matrix = stft_matrix.unsqueeze(0) stft_matrix = stft_matrix.unsqueeze(0)
# pack batch
shape = stft_matrix.size()
stft_matrix = stft_matrix.reshape(-1, *shape[-3:])
dtype = stft_matrix.dtype dtype = stft_matrix.dtype
device = stft_matrix.device device = stft_matrix.device
fft_size = stft_matrix.size(1) fft_size = stft_matrix.size(1)
...@@ -208,8 +212,12 @@ def istft( ...@@ -208,8 +212,12 @@ def istft(
y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len) y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)
# unpack batch
y = y.reshape(shape[:-3] + y.shape[-1:])
if stft_matrix_dim == 3: # remove the channel dimension if stft_matrix_dim == 3: # remove the channel dimension
y = y.squeeze(0) y = y.squeeze(0)
return y return y
...@@ -514,14 +522,14 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): ...@@ -514,14 +522,14 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
dtype=complex_specgrams.dtype) dtype=complex_specgrams.dtype)
alphas = time_steps % 1.0 alphas = time_steps % 1.0
phase_0 = angle(complex_specgrams[:, :, :1]) phase_0 = angle(complex_specgrams[..., :1, :])
# Time Padding # Time Padding
complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])
# (new_bins, freq, 2) # (new_bins, freq, 2)
complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()] complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()] complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())
angle_0 = angle(complex_specgrams_0) angle_0 = angle(complex_specgrams_0)
angle_1 = angle(complex_specgrams_1) angle_1 = angle(complex_specgrams_1)
...@@ -534,7 +542,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): ...@@ -534,7 +542,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
# Compute Phase Accum # Compute Phase Accum
phase = phase + phase_advance phase = phase + phase_advance
phase = torch.cat([phase_0, phase[:, :, :-1]], dim=-1) phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
phase_acc = torch.cumsum(phase, -1) phase_acc = torch.cumsum(phase, -1)
mag = alphas * norm_1 + (1 - alphas) * norm_0 mag = alphas * norm_1 + (1 - alphas) * norm_0
...@@ -554,7 +562,7 @@ def lfilter(waveform, a_coeffs, b_coeffs): ...@@ -554,7 +562,7 @@ def lfilter(waveform, a_coeffs, b_coeffs):
Performs an IIR filter by evaluating difference equation. Performs an IIR filter by evaluating difference equation.
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`. Must be normalized to -1 to 1. waveform (torch.Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`. a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`. Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`.
Must be same size as b_coeffs (pad with 0's as necessary). Must be same size as b_coeffs (pad with 0's as necessary).
...@@ -563,10 +571,16 @@ def lfilter(waveform, a_coeffs, b_coeffs): ...@@ -563,10 +571,16 @@ def lfilter(waveform, a_coeffs, b_coeffs):
Must be same size as a_coeffs (pad with 0's as necessary). Must be same size as a_coeffs (pad with 0's as necessary).
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`. Output will be clipped to -1 to 1. output_waveform (torch.Tensor): Dimension of `(..., time)`. Output will be clipped to -1 to 1.
""" """
dim = waveform.dim()
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
assert(a_coeffs.size(0) == b_coeffs.size(0)) assert(a_coeffs.size(0) == b_coeffs.size(0))
assert(len(waveform.size()) == 2) assert(len(waveform.size()) == 2)
assert(waveform.device == a_coeffs.device) assert(waveform.device == a_coeffs.device)
...@@ -606,7 +620,14 @@ def lfilter(waveform, a_coeffs, b_coeffs): ...@@ -606,7 +620,14 @@ def lfilter(waveform, a_coeffs, b_coeffs):
padded_output_waveform[:, i_sample + n_order - 1] = o0 padded_output_waveform[:, i_sample + n_order - 1] = o0
return torch.min(ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):])) output = torch.min(
ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):])
)
# unpack batch
output = output.reshape(shape[:-1] + output.shape[-1:])
return output
@torch.jit.script @torch.jit.script
...@@ -817,12 +838,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): ...@@ -817,12 +838,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
:math:`N` is (`win_length`-1)//2. :math:`N` is (`win_length`-1)//2.
Args: Args:
specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time) specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)
win_length (int): The window length used for computing delta win_length (int): The window length used for computing delta
mode (str): Mode parameter passed to padding mode (str): Mode parameter passed to padding
Returns: Returns:
deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time) deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)
Example Example
>>> specgram = torch.randn(1, 40, 1000) >>> specgram = torch.randn(1, 40, 1000)
...@@ -830,9 +851,11 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): ...@@ -830,9 +851,11 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
>>> delta2 = compute_deltas(delta) >>> delta2 = compute_deltas(delta)
""" """
# pack batch
shape = specgram.size()
specgram = specgram.reshape(1, -1, shape[-1])
assert win_length >= 3 assert win_length >= 3
assert specgram.dim() == 3
assert not specgram.shape[1] % specgram.shape[0]
n = (win_length - 1) // 2 n = (win_length - 1) // 2
...@@ -844,12 +867,15 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): ...@@ -844,12 +867,15 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
kernel = ( kernel = (
torch torch
.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype) .arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype)
.repeat(specgram.shape[1], specgram.shape[0], 1) .repeat(specgram.shape[1], 1, 1)
) )
return torch.nn.functional.conv1d( output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
specgram, kernel, groups=specgram.shape[1] // specgram.shape[0]
) / denom # unpack batch
output = output.reshape(shape)
return output
@torch.jit.script @torch.jit.script
...@@ -982,16 +1008,22 @@ def detect_pitch_frequency( ...@@ -982,16 +1008,22 @@ def detect_pitch_frequency(
It is implemented using normalized cross-correlation function and median smoothing. It is implemented using normalized cross-correlation function and median smoothing.
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, freq, time) waveform (torch.Tensor): Tensor of audio of dimension (..., freq, time)
sample_rate (int): The sample rate of the waveform (Hz) sample_rate (int): The sample rate of the waveform (Hz)
win_length (int): The window length for median smoothing (in number of frames) win_length (int): The window length for median smoothing (in number of frames)
freq_low (int): Lowest frequency that can be detected (Hz) freq_low (int): Lowest frequency that can be detected (Hz)
freq_high (int): Highest frequency that can be detected (Hz) freq_high (int): Highest frequency that can be detected (Hz)
Returns: Returns:
freq (torch.Tensor): Tensor of audio of dimension (channel, frame) freq (torch.Tensor): Tensor of audio of dimension (..., frame)
""" """
dim = waveform.dim()
# pack batch
shape = waveform.size()
waveform = waveform.reshape([-1] + shape[-1:])
nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low) nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
indices = _find_max_per_frame(nccf, sample_rate, freq_high) indices = _find_max_per_frame(nccf, sample_rate, freq_high)
indices = _median_smoothing(indices, win_length) indices = _median_smoothing(indices, win_length)
...@@ -1000,4 +1032,7 @@ def detect_pitch_frequency( ...@@ -1000,4 +1032,7 @@ def detect_pitch_frequency(
EPSILON = 10 ** (-9) EPSILON = 10 ** (-9)
freq = sample_rate / (EPSILON + indices.to(torch.float)) freq = sample_rate / (EPSILON + indices.to(torch.float))
# unpack batch
freq = freq.reshape(shape[:-1] + freq.shape[-1:])
return freq return freq
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment