Commit e3024341 authored by engineerchuan, committed by Vincent QB
Browse files

Fix lfilter for GPU machine (#291)

* Ensure that lfilter works with GPU
* Add option to load on different device
* Debugging tests of lfilter on GPU
* allowing arbitrary types into lfilter
parent 401e7aee
......@@ -12,7 +12,7 @@ import time
class TestFunctionalFiltering(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
def test_lfilter_basic(self):
def _test_lfilter_basic(self, dtype, device):
"""
Create a very basic signal,
Then make a simple 4th order delay
......@@ -20,16 +20,27 @@ class TestFunctionalFiltering(unittest.TestCase):
"""
torch.random.manual_seed(42)
waveform = torch.rand(2, 44100 * 10)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=torch.float32)
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
waveform = torch.rand(2, 44100 * 1, dtype=dtype, device=device)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=dtype, device=device)
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=dtype, device=device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert torch.allclose(
waveform[:, 0:-3], output_waveform[:, 3:], atol=1e-5
)
assert torch.allclose(waveform[:, 0:-3], output_waveform[:, 3:], atol=1e-5)
def test_lfilter(self):
def test_lfilter_basic(self):
    """Run the basic delay-filter check with float32 tensors on the CPU."""
    cpu = torch.device("cpu")
    self._test_lfilter_basic(dtype=torch.float32, device=cpu)
def test_lfilter_basic_double(self):
    """Run the basic delay-filter check with float64 tensors on the CPU."""
    cpu = torch.device("cpu")
    self._test_lfilter_basic(dtype=torch.float64, device=cpu)
def test_lfilter_basic_gpu(self):
if torch.cuda.is_available():
self._test_lfilter_basic(torch.float32, torch.device("cuda:0"))
else:
print("skipping GPU test for lfilter_basic because device not available")
pass
def _test_lfilter(self, waveform, device):
"""
Design an IIR lowpass filter using scipy.signal filter design
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
......@@ -48,7 +59,8 @@ class TestFunctionalFiltering(unittest.TestCase):
0.00841964,
-0.0051152,
0.00299893,
]
],
device=device,
)
a_coeffs = torch.tensor(
[
......@@ -59,16 +71,33 @@ class TestFunctionalFiltering(unittest.TestCase):
8.49018171,
-3.3066882,
0.56088705,
]
],
device=device,
)
filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert len(output_waveform.size()) == 2
assert output_waveform.size(0) == waveform.size(0)
assert output_waveform.size(1) == waveform.size(1)
def test_lfilter(self):
    """Load the whitenoise.mp3 asset and run the lfilter check on the CPU."""
    asset_path = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
    noise, _sample_rate = torchaudio.load(asset_path, normalization=True)
    cpu = torch.device("cpu")
    self._test_lfilter(noise, cpu)
def test_lfilter_gpu(self):
if torch.cuda.is_available():
filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
waveform, _ = torchaudio.load(filepath, normalization=True)
cuda0 = torch.device("cuda:0")
cuda_waveform = waveform.cuda(device=cuda0)
self._test_lfilter(cuda_waveform, cuda0)
else:
print("skipping GPU test for lfilter because device not available")
pass
def test_lowpass(self):
"""
......@@ -77,17 +106,13 @@ class TestFunctionalFiltering(unittest.TestCase):
CUTOFF_FREQ = 3000
noise_filepath = os.path.join(
self.test_dirpath, "assets", "whitenoise.mp3"
)
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("lowpass", [CUTOFF_FREQ])
sox_output_waveform, sr = E.sox_build_flow_effects()
waveform, sample_rate = torchaudio.load(
noise_filepath, normalization=True
)
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
......@@ -99,17 +124,13 @@ class TestFunctionalFiltering(unittest.TestCase):
CUTOFF_FREQ = 2000
noise_filepath = os.path.join(
self.test_dirpath, "assets", "whitenoise.mp3"
)
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("highpass", [CUTOFF_FREQ])
sox_output_waveform, sr = E.sox_build_flow_effects()
waveform, sample_rate = torchaudio.load(
noise_filepath, normalization=True
)
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.highpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
# TBD - this fails at the 1e-4 level, debug why
......
......@@ -525,21 +525,25 @@ def lfilter(waveform, a_coeffs, b_coeffs):
Returns:
output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`. Output will be clipped to -1 to 1.
Will be on the same device as the inputs.
"""
assert(waveform.dtype == torch.float32)
assert(a_coeffs.size(0) == b_coeffs.size(0))
assert(len(waveform.size()) == 2)
assert(waveform.device == a_coeffs.device)
assert(b_coeffs.device == a_coeffs.device)
device = waveform.device
dtype = waveform.dtype
n_channels, n_frames = waveform.size()
n_order = a_coeffs.size(0)
assert(n_order > 0)
# Pad the input and create output
padded_waveform = torch.zeros(n_channels, n_frames + n_order - 1)
padded_waveform = torch.zeros(n_channels, n_frames + n_order - 1, dtype=dtype, device=device)
padded_waveform[:, (n_order - 1):] = waveform
padded_output_waveform = torch.zeros(n_channels, n_frames + n_order - 1)
padded_output_waveform = torch.zeros(n_channels, n_frames + n_order - 1, dtype=dtype, device=device)
# Set up the coefficients matrix
# Flip order, repeat, and transpose
......@@ -547,12 +551,12 @@ def lfilter(waveform, a_coeffs, b_coeffs):
b_coeffs_filled = b_coeffs.flip(0).repeat(n_channels, 1).t()
# Set up a few other utilities
a0_repeated = torch.ones(n_channels) * a_coeffs[0]
ones = torch.ones(n_channels, n_frames)
a0_repeated = torch.ones(n_channels, dtype=dtype, device=device) * a_coeffs[0]
ones = torch.ones(n_channels, n_frames, dtype=dtype, device=device)
for i_frame in range(n_frames):
o0 = torch.zeros(n_channels)
o0 = torch.zeros(n_channels, dtype=dtype, device=device)
windowed_input_signal = padded_waveform[:, i_frame:(i_frame + n_order)]
windowed_output_signal = padded_output_waveform[:, i_frame:(i_frame + n_order)]
......@@ -585,10 +589,13 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`
"""
assert(waveform.dtype == torch.float32)
device = waveform.device
dtype = waveform.dtype
output_waveform = lfilter(
waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
waveform,
torch.tensor([a0, a1, a2], dtype=dtype, device=device),
torch.tensor([b0, b1, b2], dtype=dtype, device=device)
)
return output_waveform
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment