Commit e3024341 authored by engineerchuan, committed by Vincent QB

Fix lfilter for GPU machine (#291)

* Ensure that lfilter works with GPU
* Add option to load on different device
* Debugging tests of lfilter on GPU
* Allow arbitrary dtypes in lfilter
parent 401e7aee
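
As a quick illustration of what this change enables (a minimal sketch, not part of the commit; it assumes a torchaudio build containing this patch and, for the CUDA branch, an available GPU), lfilter can now be called directly on tensors of any floating dtype and on any device:

import torch
import torchaudio.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Two-channel test signal created directly on the target device, in float64.
waveform = torch.rand(2, 44100, dtype=torch.float64, device=device)

# A pure 3-sample delay filter: b = [0, 0, 0, 1], a = [1, 0, 0, 0].
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=waveform.dtype, device=device)
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=waveform.dtype, device=device)

# The output keeps the dtype and device of the inputs.
output = F.lfilter(waveform, a_coeffs, b_coeffs)
print(output.device, output.dtype)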
@@ -12,7 +12,7 @@ import time
 class TestFunctionalFiltering(unittest.TestCase):
     test_dirpath, test_dir = common_utils.create_temp_assets_dir()

-    def test_lfilter_basic(self):
+    def _test_lfilter_basic(self, dtype, device):
         """
         Create a very basic signal,
         Then make a simple 4th order delay
@@ -20,16 +20,27 @@ class TestFunctionalFiltering(unittest.TestCase):
         """
         torch.random.manual_seed(42)

-        waveform = torch.rand(2, 44100 * 10)
-        b_coeffs = torch.tensor([0, 0, 0, 1], dtype=torch.float32)
-        a_coeffs = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
+        waveform = torch.rand(2, 44100 * 1, dtype=dtype, device=device)
+        b_coeffs = torch.tensor([0, 0, 0, 1], dtype=dtype, device=device)
+        a_coeffs = torch.tensor([1, 0, 0, 0], dtype=dtype, device=device)

         output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
-        assert torch.allclose(
-            waveform[:, 0:-3], output_waveform[:, 3:], atol=1e-5
-        )
+        assert torch.allclose(waveform[:, 0:-3], output_waveform[:, 3:], atol=1e-5)

-    def test_lfilter(self):
+    def test_lfilter_basic(self):
+        self._test_lfilter_basic(torch.float32, torch.device("cpu"))
+
+    def test_lfilter_basic_double(self):
+        self._test_lfilter_basic(torch.float64, torch.device("cpu"))
+
+    def test_lfilter_basic_gpu(self):
+        if torch.cuda.is_available():
+            self._test_lfilter_basic(torch.float32, torch.device("cuda:0"))
+        else:
+            print("skipping GPU test for lfilter_basic because device not available")
+            pass
+
+    def _test_lfilter(self, waveform, device):
         """
         Design an IIR lowpass filter using scipy.signal filter design
         https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
@@ -48,7 +59,8 @@ class TestFunctionalFiltering(unittest.TestCase):
                 0.00841964,
                 -0.0051152,
                 0.00299893,
-            ]
+            ],
+            device=device,
         )
         a_coeffs = torch.tensor(
             [
@@ -59,16 +71,33 @@ class TestFunctionalFiltering(unittest.TestCase):
                 8.49018171,
                 -3.3066882,
                 0.56088705,
-            ]
+            ],
+            device=device,
         )

-        filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
-        waveform, sample_rate = torchaudio.load(filepath, normalization=True)
         output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)

         assert len(output_waveform.size()) == 2
         assert output_waveform.size(0) == waveform.size(0)
         assert output_waveform.size(1) == waveform.size(1)

+    def test_lfilter(self):
+        filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
+        waveform, _ = torchaudio.load(filepath, normalization=True)
+        self._test_lfilter(waveform, torch.device("cpu"))
+
+    def test_lfilter_gpu(self):
+        if torch.cuda.is_available():
+            filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
+            waveform, _ = torchaudio.load(filepath, normalization=True)
+            cuda0 = torch.device("cuda:0")
+            cuda_waveform = waveform.cuda(device=cuda0)
+            self._test_lfilter(cuda_waveform, cuda0)
+        else:
+            print("skipping GPU test for lfilter because device not available")
+            pass
+
     def test_lowpass(self):
         """
@@ -77,17 +106,13 @@ class TestFunctionalFiltering(unittest.TestCase):
         CUTOFF_FREQ = 3000

-        noise_filepath = os.path.join(
-            self.test_dirpath, "assets", "whitenoise.mp3"
-        )
+        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
         E = torchaudio.sox_effects.SoxEffectsChain()
         E.set_input_file(noise_filepath)
         E.append_effect_to_chain("lowpass", [CUTOFF_FREQ])
         sox_output_waveform, sr = E.sox_build_flow_effects()

-        waveform, sample_rate = torchaudio.load(
-            noise_filepath, normalization=True
-        )
+        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
         output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)

         assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
@@ -99,17 +124,13 @@ class TestFunctionalFiltering(unittest.TestCase):
         CUTOFF_FREQ = 2000

-        noise_filepath = os.path.join(
-            self.test_dirpath, "assets", "whitenoise.mp3"
-        )
+        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
         E = torchaudio.sox_effects.SoxEffectsChain()
         E.set_input_file(noise_filepath)
         E.append_effect_to_chain("highpass", [CUTOFF_FREQ])
         sox_output_waveform, sr = E.sox_build_flow_effects()

-        waveform, sample_rate = torchaudio.load(
-            noise_filepath, normalization=True
-        )
+        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
         output_waveform = F.highpass_biquad(waveform, sample_rate, CUTOFF_FREQ)

         # TBD - this fails at the 1e-4 level, debug why
...
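
An aside on the basic test above: with b_coeffs = [0, 0, 0, 1] and a_coeffs = [1, 0, 0, 0] the difference equation collapses to a pure three-sample delay, y[n] = x[n - 3], which is why the assertion compares waveform[:, 0:-3] against output_waveform[:, 3:]. A small cross-check of the same filter with scipy.signal.lfilter (illustrative only, not part of the patch):

import numpy as np
from scipy import signal

x = np.random.rand(44100).astype(np.float32)
b = [0.0, 0.0, 0.0, 1.0]   # feedforward (numerator) coefficients
a = [1.0, 0.0, 0.0, 0.0]   # feedback (denominator) coefficients

y = signal.lfilter(b, a, x)

# The filter only delays the signal by three samples (zero initial conditions).
assert np.allclose(y[3:], x[:-3])
assert np.allclose(y[:3], 0.0)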
@@ -525,21 +525,25 @@ def lfilter(waveform, a_coeffs, b_coeffs):
     Returns:
         output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`.  Output will be clipped to -1 to 1.
+        Will be on the same device as the inputs.
     """

-    assert(waveform.dtype == torch.float32)
     assert(a_coeffs.size(0) == b_coeffs.size(0))
     assert(len(waveform.size()) == 2)
+    assert(waveform.device == a_coeffs.device)
+    assert(b_coeffs.device == a_coeffs.device)
+
+    device = waveform.device
+    dtype = waveform.dtype
     n_channels, n_frames = waveform.size()
     n_order = a_coeffs.size(0)
     assert(n_order > 0)

     # Pad the input and create output
-    padded_waveform = torch.zeros(n_channels, n_frames + n_order - 1)
+    padded_waveform = torch.zeros(n_channels, n_frames + n_order - 1, dtype=dtype, device=device)
     padded_waveform[:, (n_order - 1):] = waveform
-    padded_output_waveform = torch.zeros(n_channels, n_frames + n_order - 1)
+    padded_output_waveform = torch.zeros(n_channels, n_frames + n_order - 1, dtype=dtype, device=device)

     # Set up the coefficients matrix
     # Flip order, repeat, and transpose
@@ -547,12 +551,12 @@ def lfilter(waveform, a_coeffs, b_coeffs):
     b_coeffs_filled = b_coeffs.flip(0).repeat(n_channels, 1).t()

     # Set up a few other utilities
-    a0_repeated = torch.ones(n_channels) * a_coeffs[0]
-    ones = torch.ones(n_channels, n_frames)
+    a0_repeated = torch.ones(n_channels, dtype=dtype, device=device) * a_coeffs[0]
+    ones = torch.ones(n_channels, n_frames, dtype=dtype, device=device)

     for i_frame in range(n_frames):
-        o0 = torch.zeros(n_channels)
+        o0 = torch.zeros(n_channels, dtype=dtype, device=device)

         windowed_input_signal = padded_waveform[:, i_frame:(i_frame + n_order)]
         windowed_output_signal = padded_output_waveform[:, i_frame:(i_frame + n_order)]
@@ -585,10 +589,13 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
         output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`
     """

-    assert(waveform.dtype == torch.float32)
+    device = waveform.device
+    dtype = waveform.dtype

     output_waveform = lfilter(
-        waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
+        waveform,
+        torch.tensor([a0, a1, a2], dtype=dtype, device=device),
+        torch.tensor([b0, b1, b2], dtype=dtype, device=device)
     )
     return output_waveform
...
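
The core of the fix in the hunks above is that every intermediate tensor allocated by lfilter and biquad (padded_waveform, padded_output_waveform, a0_repeated, ones, o0, and the biquad coefficient tensors) now inherits dtype and device from the input instead of using the defaults. A minimal sketch of the general failure mode this avoids (illustrative only, not taken from the patch; assumes a CUDA device is present):

import torch

if torch.cuda.is_available():
    x = torch.rand(2, 4, device="cuda:0")
    w = torch.rand(4, 2)  # allocated with defaults: CPU, float32

    try:
        torch.mm(x, w)  # mixing CUDA and CPU operands raises a RuntimeError
    except RuntimeError as err:
        print("device mismatch:", err)

    # The pattern applied throughout the patch: thread dtype and device from the input.
    w_matched = torch.rand(4, 2, dtype=x.dtype, device=x.device)
    y = torch.mm(x, w_matched)
    print(y.device, y.dtype)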