Unverified Commit f5f79d1d authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Batching for transforms (#337)

* batching for transforms.

* test for batching.

* update readme.
parent 1500d4ef
......@@ -129,7 +129,7 @@ Transforms expect and return the following dimensions.
* `MuLawDecode`: (channel, time) -> (channel, time)
* `Resample`: (channel, time) -> (channel, time)
Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase.
Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. optional batching and channel dimensions.
Contributing Guidelines
-----------------------
......
......@@ -21,6 +21,10 @@ class TestFunctional(unittest.TestCase):
number_of_trials = 100
specgram = torch.tensor([1., 2., 3., 4.])
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3')
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
......@@ -46,6 +50,20 @@ class TestFunctional(unittest.TestCase):
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def test_batch_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
# Single then transform then batch
expected = F.detect_pitch_frequency(waveform, sample_rate)
expected = expected.unsqueeze(0).repeat(3, 1, 1)
# Batch then transform
waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
computed = F.detect_pitch_frequency(waveform, sample_rate)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)]
......@@ -58,6 +76,8 @@ class TestFunctional(unittest.TestCase):
# operation to check whether we can reconstruct signal
for data_size in self.data_sizes:
for i in range(self.number_of_trials):
# Non-batch
sound = common_utils.random_float_tensor(i, data_size)
stft = torch.stft(sound, **kwargs)
......@@ -65,6 +85,14 @@ class TestFunctional(unittest.TestCase):
self._compare_estimate(sound, estimate)
# Batch
stft = torch.stft(sound, **kwargs)
stft = stft.repeat(3, 1, 1, 1, 1)
sound = sound.repeat(3, 1, 1)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
self._compare_estimate(sound, estimate)
def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided
kwargs1 = {
......@@ -326,15 +354,30 @@ class TestFunctional(unittest.TestCase):
for filename, freq_ref in tests:
waveform, sample_rate = torchaudio.load(filename)
# Convert to stereo for testing purposes
waveform = waveform.repeat(2, 1, 1)
freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
threshold = 1
s = ((freq - freq_ref).abs() > threshold).sum()
self.assertFalse(s)
# Convert to stereo and batch for testing purposes
freq = freq.repeat(3, 2, 1, 1)
waveform = waveform.repeat(3, 2, 1, 1)
freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
assert torch.allclose(freq, freq2, atol=1e-5)
def _test_batch(self, functional):
waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100
# Single then transform then batch
expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1)
# Batch then transform
waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
computed = functional(waveform)
def _num_stft_bins(signal_len, fft_len, hop_length, pad):
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
......
......@@ -313,6 +313,45 @@ class Tester(unittest.TestCase):
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786)
# Single then transform then batch
expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_mulaw(self):
waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100
# Single then transform then batch
waveform_encoded = transforms.MuLawEncoding()(waveform)
expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)
# Batch then transform
waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
computed = transforms.MuLawEncoding()(waveform_batched)
# shape = (3, 2, 201, 1394)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
# Single then transform then batch
waveform_decoded = transforms.MuLawDecoding()(waveform_encoded)
expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)
# Batch then transform
computed = transforms.MuLawDecoding()(computed)
# shape = (3, 2, 201, 1394)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_spectrogram(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
......
......@@ -114,16 +114,20 @@ def istft(
original signal length). (Default: whole signal)
Returns:
torch.Tensor: Least squares estimation of the original signal of size
(channel, signal_length) or (signal_length)
torch.Tensor: Least squares estimation of the original signal of size (..., signal_length)
"""
stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim)
assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim)
assert stft_matrix.nelement() > 0
if stft_matrix_dim == 3:
# add a channel dimension
stft_matrix = stft_matrix.unsqueeze(0)
# pack batch
shape = stft_matrix.size()
stft_matrix = stft_matrix.reshape(-1, *shape[-3:])
dtype = stft_matrix.dtype
device = stft_matrix.device
fft_size = stft_matrix.size(1)
......@@ -208,8 +212,12 @@ def istft(
y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)
# unpack batch
y = y.reshape(shape[:-3] + y.shape[-1:])
if stft_matrix_dim == 3: # remove the channel dimension
y = y.squeeze(0)
return y
......@@ -514,14 +522,14 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
dtype=complex_specgrams.dtype)
alphas = time_steps % 1.0
phase_0 = angle(complex_specgrams[:, :, :1])
phase_0 = angle(complex_specgrams[..., :1, :])
# Time Padding
complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])
# (new_bins, freq, 2)
complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()]
complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()]
complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())
angle_0 = angle(complex_specgrams_0)
angle_1 = angle(complex_specgrams_1)
......@@ -534,7 +542,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
# Compute Phase Accum
phase = phase + phase_advance
phase = torch.cat([phase_0, phase[:, :, :-1]], dim=-1)
phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
phase_acc = torch.cumsum(phase, -1)
mag = alphas * norm_1 + (1 - alphas) * norm_0
......@@ -554,7 +562,7 @@ def lfilter(waveform, a_coeffs, b_coeffs):
Performs an IIR filter by evaluating difference equation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`. Must be normalized to -1 to 1.
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`.
Must be same size as b_coeffs (pad with 0's as necessary).
......@@ -563,10 +571,16 @@ def lfilter(waveform, a_coeffs, b_coeffs):
Must be same size as a_coeffs (pad with 0's as necessary).
Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`. Output will be clipped to -1 to 1.
output_waveform (torch.Tensor): Dimension of `(..., time)`. Output will be clipped to -1 to 1.
"""
dim = waveform.dim()
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
assert(a_coeffs.size(0) == b_coeffs.size(0))
assert(len(waveform.size()) == 2)
assert(waveform.device == a_coeffs.device)
......@@ -606,7 +620,14 @@ def lfilter(waveform, a_coeffs, b_coeffs):
padded_output_waveform[:, i_sample + n_order - 1] = o0
return torch.min(ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):]))
output = torch.min(
ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):])
)
# unpack batch
output = output.reshape(shape[:-1] + output.shape[-1:])
return output
@torch.jit.script
......@@ -817,12 +838,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
:math:`N` is (`win_length`-1)//2.
Args:
specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)
win_length (int): The window length used for computing delta
mode (str): Mode parameter passed to padding
Returns:
deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)
Example
>>> specgram = torch.randn(1, 40, 1000)
......@@ -830,9 +851,11 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
>>> delta2 = compute_deltas(delta)
"""
# pack batch
shape = specgram.size()
specgram = specgram.reshape(1, -1, shape[-1])
assert win_length >= 3
assert specgram.dim() == 3
assert not specgram.shape[1] % specgram.shape[0]
n = (win_length - 1) // 2
......@@ -844,12 +867,15 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
kernel = (
torch
.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype)
.repeat(specgram.shape[1], specgram.shape[0], 1)
.repeat(specgram.shape[1], 1, 1)
)
return torch.nn.functional.conv1d(
specgram, kernel, groups=specgram.shape[1] // specgram.shape[0]
) / denom
output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
# unpack batch
output = output.reshape(shape)
return output
@torch.jit.script
......@@ -982,16 +1008,22 @@ def detect_pitch_frequency(
It is implemented using normalized cross-correlation function and median smoothing.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., freq, time)
sample_rate (int): The sample rate of the waveform (Hz)
win_length (int): The window length for median smoothing (in number of frames)
freq_low (int): Lowest frequency that can be detected (Hz)
freq_high (int): Highest frequency that can be detected (Hz)
Returns:
freq (torch.Tensor): Tensor of audio of dimension (channel, frame)
freq (torch.Tensor): Tensor of audio of dimension (..., frame)
"""
dim = waveform.dim()
# pack batch
shape = waveform.size()
waveform = waveform.reshape([-1] + shape[-1:])
nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
indices = _find_max_per_frame(nccf, sample_rate, freq_high)
indices = _median_smoothing(indices, win_length)
......@@ -1000,4 +1032,7 @@ def detect_pitch_frequency(
EPSILON = 10 ** (-9)
freq = sample_rate / (EPSILON + indices.to(torch.float))
# unpack batch
freq = freq.reshape(shape[:-1] + freq.shape[-1:])
return freq
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment