Unverified Commit 401e7aee authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Compute deltas (#268)

* compute deltas.
* multichannel, and random test.
* documentation.
* feedback. changing name of window to win_length.
* passing padding mode.
parent 8273c3f4
......@@ -319,5 +319,6 @@ class Test_Kaldi(unittest.TestCase):
single_channel_sampled = kaldi.resample_waveform(single_channel, sample_rate, sample_rate // 2)
self.assertTrue(torch.allclose(multi_sound_sampled[i, :], single_channel_sampled, rtol=1e-4))
if __name__ == '__main__':
unittest.main()
......@@ -18,6 +18,32 @@ if IMPORT_LIBROSA:
class TestFunctional(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)]
number_of_trials = 100
specgram = torch.tensor([1., 2., 3., 4.])
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
torch.testing.assert_allclose(computed, expected, atol=atol, rtol=rtol)
def test_compute_deltas_onechannel(self):
specgram = self.specgram.unsqueeze(0).unsqueeze(0)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
self._test_compute_deltas(specgram, expected)
def test_compute_deltas_twochannel(self):
specgram = self.specgram.repeat(1, 2, 1)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]])
self._test_compute_deltas(specgram, expected)
def test_compute_deltas_randn(self):
channel = 13
n_mfcc = channel * 3
time = 1021
win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time)
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
......
......@@ -4,8 +4,9 @@ import os
import torch
import torchaudio
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import torchaudio.transforms as transforms
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import unittest
import common_utils
......@@ -281,5 +282,37 @@ class Tester(unittest.TestCase):
# we expect the downsampled signal to have half as many samples
self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
def test_compute_deltas(self):
channel = 13
n_mfcc = channel * 3
time = 1021
win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time)
transform = transforms.ComputeDeltas(win_length=win_length)
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
channel = 13
n_mfcc = channel * 3
time = 1021
win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time)
transform = transforms.ComputeDeltas(win_length=win_length)
computed_transform = transform(specgram)
computed_functional = F.compute_deltas(specgram, win_length=win_length)
torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol)
def test_compute_deltas_twochannel(self):
specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]])
transform = transforms.ComputeDeltas()
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,7 @@ __all__ = [
"biquad",
]
# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
@torch.jit.ignore
def _stft(
......@@ -652,3 +653,50 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
a1 = -2 * math.cos(w0)
a2 = 1 - alpha
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def compute_deltas(specgram, win_length=5, mode="replicate"):
# type: (Tensor, int, str) -> Tensor
r"""Compute delta coefficients of a tensor, usually a spectrogram:
.. math::
d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N} n^2}
where :math:`d_t` is the deltas at time :math:`t`,
:math:`c_t` is the spectrogram coeffcients at time :math:`t`,
:math:`N` is (`win_length`-1)//2.
Args:
specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
win_length (int): The window length used for computing delta
mode (str): Mode parameter passed to padding
Returns:
deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
Example
>>> specgram = torch.randn(1, 40, 1000)
>>> delta = compute_deltas(specgram)
>>> delta2 = compute_deltas(delta)
"""
assert win_length >= 3
assert specgram.dim() == 3
assert not specgram.shape[1] % specgram.shape[0]
n = (win_length - 1) // 2
# twice sum of integer squared
denom = n * (n + 1) * (2 * n + 1) / 3
specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)
kernel = (
torch
.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype)
.repeat(specgram.shape[1], specgram.shape[0], 1)
)
return torch.nn.functional.conv1d(
specgram, kernel, groups=specgram.shape[1] // specgram.shape[0]
) / denom
......@@ -365,3 +365,30 @@ class Resample(torch.nn.Module):
return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
class ComputeDeltas(torch.jit.ScriptModule):
r"""Compute delta coefficients of a tensor, usually a spectrogram.
See `torchaudio.functional.compute_deltas` for more details.
Args:
win_length (int): The window length used for computing delta.
"""
__constants__ = ['win_length']
def __init__(self, win_length=5, mode="replicate"):
super(ComputeDeltas, self).__init__()
self.win_length = win_length
self.mode = torch.jit.Attribute(mode, str)
@torch.jit.script_method
def forward(self, specgram):
r"""
Args:
specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
Returns:
deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
"""
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment