Unverified Commit 52decd2a authored by chin yun yu's avatar chin yun yu Committed by GitHub
Browse files

add autograd to biquad filters (#1400)

parent e4a0bd2c
...@@ -5,7 +5,7 @@ from torchaudio_unittest import common_utils ...@@ -5,7 +5,7 @@ from torchaudio_unittest import common_utils
class Autograd(common_utils.TestBaseMixin): class Autograd(common_utils.TestBaseMixin):
def test_x_grad(self): def test_lfilter_x(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
...@@ -13,7 +13,7 @@ class Autograd(common_utils.TestBaseMixin): ...@@ -13,7 +13,7 @@ class Autograd(common_utils.TestBaseMixin):
x.requires_grad = True x.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
def test_a_grad(self): def test_lfilter_a(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
...@@ -21,7 +21,7 @@ class Autograd(common_utils.TestBaseMixin): ...@@ -21,7 +21,7 @@ class Autograd(common_utils.TestBaseMixin):
a.requires_grad = True a.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
def test_b_grad(self): def test_lfilter_b(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
...@@ -29,7 +29,7 @@ class Autograd(common_utils.TestBaseMixin): ...@@ -29,7 +29,7 @@ class Autograd(common_utils.TestBaseMixin):
b.requires_grad = True b.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
def test_all_grad(self): def test_lfilter_all_inputs(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
...@@ -38,3 +38,101 @@ class Autograd(common_utils.TestBaseMixin): ...@@ -38,3 +38,101 @@ class Autograd(common_utils.TestBaseMixin):
a.requires_grad = True a.requires_grad = True
x.requires_grad = True x.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
def test_biquad(self):
torch.random.manual_seed(2434)
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device, requires_grad=True)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]), eps=1e-10)
def test_band_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.band_biquad, (x, sr, central_freq, Q))
def test_band_biquad_with_noise(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.band_biquad, (x, sr, central_freq, Q, True))
def test_bass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(100, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bass_biquad, (x, sr, gain, central_freq, Q))
def test_treble_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(3000, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.treble_biquad, (x, sr, gain, central_freq, Q))
def test_allpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.allpass_biquad, (x, sr, central_freq, Q))
def test_lowpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.lowpass_biquad, (x, sr, cutoff_freq, Q))
def test_highpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.highpass_biquad, (x, sr, cutoff_freq, Q))
def test_bandpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q))
def test_bandpass_biquad_with_const_skirt_gain(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q, True))
def test_equalizer_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.equalizer_biquad, (x, sr, central_freq, gain, Q))
def test_bandreject_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandreject_biquad, (x, sr, central_freq, Q))
...@@ -73,8 +73,8 @@ def allpass_biquad( ...@@ -73,8 +73,8 @@ def allpass_biquad(
Args: Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)` waveform(torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz) central_freq (float or torch.Tensor): central frequency (in Hz)
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns: Returns:
Tensor: Waveform of dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
...@@ -83,14 +83,20 @@ def allpass_biquad( ...@@ -83,14 +83,20 @@ def allpass_biquad(
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
""" """
dtype = waveform.dtype
device = waveform.device
central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
Q = torch.as_tensor(Q, dtype=dtype, device=device)
w0 = 2 * math.pi * central_freq / sample_rate w0 = 2 * math.pi * central_freq / sample_rate
alpha = math.sin(w0) / 2 / Q
alpha = torch.sin(w0) / 2 / Q
b0 = 1 - alpha b0 = 1 - alpha
b1 = -2 * math.cos(w0) b1 = -2 * torch.cos(w0)
b2 = 1 + alpha b2 = 1 + alpha
a0 = 1 + alpha a0 = 1 + alpha
a1 = -2 * math.cos(w0) a1 = -2 * torch.cos(w0)
a2 = 1 - alpha a2 = 1 - alpha
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
...@@ -107,8 +113,8 @@ def band_biquad( ...@@ -107,8 +113,8 @@ def band_biquad(
Args: Args:
waveform (Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz) central_freq (float or torch.Tensor): central frequency (in Hz)
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``). Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).
noise (bool, optional) : If ``True``, uses the alternate mode for un-pitched audio (e.g. percussion). noise (bool, optional) : If ``True``, uses the alternate mode for un-pitched audio (e.g. percussion).
If ``False``, uses mode oriented to pitched audio, i.e. voice, singing, If ``False``, uses mode oriented to pitched audio, i.e. voice, singing,
or instrumental music (Default: ``False``). or instrumental music (Default: ``False``).
...@@ -120,18 +126,23 @@ def band_biquad( ...@@ -120,18 +126,23 @@ def band_biquad(
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
""" """
dtype = waveform.dtype
device = waveform.device
central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
Q = torch.as_tensor(Q, dtype=dtype, device=device)
w0 = 2 * math.pi * central_freq / sample_rate w0 = 2 * math.pi * central_freq / sample_rate
bw_Hz = central_freq / Q bw_Hz = central_freq / Q
a0 = 1.0 a0 = 1.0
a2 = math.exp(-2 * math.pi * bw_Hz / sample_rate) a2 = torch.exp(-2 * math.pi * bw_Hz / sample_rate)
a1 = -4 * a2 / (1 + a2) * math.cos(w0) a1 = -4 * a2 / (1 + a2) * torch.cos(w0)
b0 = math.sqrt(1 - a1 * a1 / (4 * a2)) * (1 - a2) b0 = torch.sqrt(1 - a1 * a1 / (4 * a2)) * (1 - a2)
if noise: if noise:
mult = math.sqrt(((1 + a2) * (1 + a2) - a1 * a1) * (1 - a2) / (1 + a2)) / b0 mult = torch.sqrt(((1 + a2) * (1 + a2) - a1 * a1) * (1 - a2) / (1 + a2)) / b0
b0 *= mult b0 = mult * b0
b1 = 0.0 b1 = 0.0
b2 = 0.0 b2 = 0.0
...@@ -151,8 +162,8 @@ def bandpass_biquad( ...@@ -151,8 +162,8 @@ def bandpass_biquad(
Args: Args:
waveform (Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz) central_freq (float or torch.Tensor): central frequency (in Hz)
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
const_skirt_gain (bool, optional) : If ``True``, uses a constant skirt gain (peak gain = Q). const_skirt_gain (bool, optional) : If ``True``, uses a constant skirt gain (peak gain = Q).
If ``False``, uses a constant 0dB peak gain. (Default: ``False``) If ``False``, uses a constant 0dB peak gain. (Default: ``False``)
...@@ -163,15 +174,20 @@ def bandpass_biquad( ...@@ -163,15 +174,20 @@ def bandpass_biquad(
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
""" """
dtype = waveform.dtype
device = waveform.device
central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
Q = torch.as_tensor(Q, dtype=dtype, device=device)
w0 = 2 * math.pi * central_freq / sample_rate w0 = 2 * math.pi * central_freq / sample_rate
alpha = math.sin(w0) / 2 / Q alpha = torch.sin(w0) / 2 / Q
temp = math.sin(w0) / 2 if const_skirt_gain else alpha temp = torch.sin(w0) / 2 if const_skirt_gain else alpha
b0 = temp b0 = temp
b1 = 0.0 b1 = 0.0
b2 = -temp b2 = -temp
a0 = 1 + alpha a0 = 1 + alpha
a1 = -2 * math.cos(w0) a1 = -2 * torch.cos(w0)
a2 = 1 - alpha a2 = 1 - alpha
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
...@@ -184,8 +200,8 @@ def bandreject_biquad( ...@@ -184,8 +200,8 @@ def bandreject_biquad(
Args: Args:
waveform (Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz) central_freq (float or torch.Tensor): central frequency (in Hz)
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns: Returns:
Tensor: Waveform of dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
...@@ -194,14 +210,19 @@ def bandreject_biquad( ...@@ -194,14 +210,19 @@ def bandreject_biquad(
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
""" """
dtype = waveform.dtype
device = waveform.device
central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
Q = torch.as_tensor(Q, dtype=dtype, device=device)
w0 = 2 * math.pi * central_freq / sample_rate w0 = 2 * math.pi * central_freq / sample_rate
alpha = math.sin(w0) / 2 / Q alpha = torch.sin(w0) / 2 / Q
b0 = 1.0 b0 = 1.0
b1 = -2 * math.cos(w0) b1 = -2 * torch.cos(w0)
b2 = 1.0 b2 = 1.0
a0 = 1 + alpha a0 = 1 + alpha
a1 = -2 * math.cos(w0) a1 = -2 * torch.cos(w0)
a2 = 1 - alpha a2 = 1 - alpha
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
...@@ -218,9 +239,9 @@ def bass_biquad( ...@@ -218,9 +239,9 @@ def bass_biquad(
Args: Args:
waveform (Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
gain (float): desired gain at the boost (or attenuation) in dB. gain (float or torch.Tensor): desired gain at the boost (or attenuation) in dB.
central_freq (float, optional): central frequency (in Hz). (Default: ``100``) central_freq (float or torch.Tensor, optional): central frequency (in Hz). (Default: ``100``)
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``). Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).
Returns: Returns:
Tensor: Waveform of dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
...@@ -229,13 +250,19 @@ def bass_biquad( ...@@ -229,13 +250,19 @@ def bass_biquad(
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
""" """
dtype = waveform.dtype
device = waveform.device
central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
Q = torch.as_tensor(Q, dtype=dtype, device=device)
gain = torch.as_tensor(gain, dtype=dtype, device=device)
w0 = 2 * math.pi * central_freq / sample_rate w0 = 2 * math.pi * central_freq / sample_rate
alpha = math.sin(w0) / 2 / Q alpha = torch.sin(w0) / 2 / Q
A = math.exp(gain / 40 * math.log(10)) A = torch.exp(gain / 40 * math.log(10))
temp1 = 2 * math.sqrt(A) * alpha temp1 = 2 * torch.sqrt(A) * alpha
temp2 = (A - 1) * math.cos(w0) temp2 = (A - 1) * torch.cos(w0)
temp3 = (A + 1) * math.cos(w0) temp3 = (A + 1) * torch.cos(w0)
b0 = A * ((A + 1) - temp2 + temp1) b0 = A * ((A + 1) - temp2 + temp1)
b1 = 2 * A * ((A - 1) - temp3) b1 = 2 * A * ((A - 1) - temp3)
...@@ -255,12 +282,12 @@ def biquad( ...@@ -255,12 +282,12 @@ def biquad(
Args: Args:
waveform (Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
b0 (float): numerator coefficient of current input, x[n] b0 (float or torch.Tensor): numerator coefficient of current input, x[n]
b1 (float): numerator coefficient of input one time step ago x[n-1] b1 (float or torch.Tensor): numerator coefficient of input one time step ago x[n-1]
b2 (float): numerator coefficient of input two time steps ago x[n-2] b2 (float or torch.Tensor): numerator coefficient of input two time steps ago x[n-2]
a0 (float): denominator coefficient of current output y[n], typically 1 a0 (float or torch.Tensor): denominator coefficient of current output y[n], typically 1
a1 (float): denominator coefficient of current output y[n-1] a1 (float or torch.Tensor): denominator coefficient of current output y[n-1]
a2 (float): denominator coefficient of current output y[n-2] a2 (float or torch.Tensor): denominator coefficient of current output y[n-2]
Returns: Returns:
Tensor: Waveform with dimension of `(..., time)` Tensor: Waveform with dimension of `(..., time)`
...@@ -269,10 +296,17 @@ def biquad( ...@@ -269,10 +296,17 @@ def biquad(
device = waveform.device device = waveform.device
dtype = waveform.dtype dtype = waveform.dtype
b0 = torch.as_tensor(b0, dtype=dtype, device=device).view(1)
b1 = torch.as_tensor(b1, dtype=dtype, device=device).view(1)
b2 = torch.as_tensor(b2, dtype=dtype, device=device).view(1)
a0 = torch.as_tensor(a0, dtype=dtype, device=device).view(1)
a1 = torch.as_tensor(a1, dtype=dtype, device=device).view(1)
a2 = torch.as_tensor(a2, dtype=dtype, device=device).view(1)
output_waveform = lfilter( output_waveform = lfilter(
waveform, waveform,
torch.tensor([a0, a1, a2], dtype=dtype, device=device), torch.cat([a0, a1, a2]),
torch.tensor([b0, b1, b2], dtype=dtype, device=device), torch.cat([b0, b1, b2]),
) )
return output_waveform return output_waveform
...@@ -584,21 +618,27 @@ def equalizer_biquad( ...@@ -584,21 +618,27 @@ def equalizer_biquad(
waveform (Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
center_freq (float): filter's central frequency center_freq (float): filter's central frequency
gain (float): desired gain at the boost (or attenuation) in dB gain (float or torch.Tensor): desired gain at the boost (or attenuation) in dB
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns: Returns:
Tensor: Waveform of dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
""" """
dtype = waveform.dtype
device = waveform.device
center_freq = torch.as_tensor(center_freq, dtype=dtype, device=device)
Q = torch.as_tensor(Q, dtype=dtype, device=device)
gain = torch.as_tensor(gain, dtype=dtype, device=device)
w0 = 2 * math.pi * center_freq / sample_rate w0 = 2 * math.pi * center_freq / sample_rate
A = math.exp(gain / 40.0 * math.log(10)) A = torch.exp(gain / 40.0 * math.log(10))
alpha = math.sin(w0) / 2 / Q alpha = torch.sin(w0) / 2 / Q
b0 = 1 + alpha * A b0 = 1 + alpha * A
b1 = -2 * math.cos(w0) b1 = -2 * torch.cos(w0)
b2 = 1 - alpha * A b2 = 1 - alpha * A
a0 = 1 + alpha / A a0 = 1 + alpha / A
a1 = -2 * math.cos(w0) a1 = -2 * torch.cos(w0)
a2 = 1 - alpha / A a2 = 1 - alpha / A
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
...@@ -790,20 +830,25 @@ def highpass_biquad( ...@@ -790,20 +830,25 @@ def highpass_biquad(
Args: Args:
waveform (Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency cutoff_freq (float or torch.Tensor): filter cutoff frequency
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns: Returns:
Tensor: Waveform dimension of `(..., time)` Tensor: Waveform dimension of `(..., time)`
""" """
dtype = waveform.dtype
device = waveform.device
cutoff_freq = torch.as_tensor(cutoff_freq, dtype=dtype, device=device)
Q = torch.as_tensor(Q, dtype=dtype, device=device)
w0 = 2 * math.pi * cutoff_freq / sample_rate w0 = 2 * math.pi * cutoff_freq / sample_rate
alpha = math.sin(w0) / 2.0 / Q alpha = torch.sin(w0) / 2.0 / Q
b0 = (1 + math.cos(w0)) / 2 b0 = (1 + torch.cos(w0)) / 2
b1 = -1 - math.cos(w0) b1 = -1 - torch.cos(w0)
b2 = b0 b2 = b0
a0 = 1 + alpha a0 = 1 + alpha
a1 = -2 * math.cos(w0) a1 = -2 * torch.cos(w0)
a2 = 1 - alpha a2 = 1 - alpha
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
...@@ -924,20 +969,25 @@ def lowpass_biquad( ...@@ -924,20 +969,25 @@ def lowpass_biquad(
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)` waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency cutoff_freq (float or torch.Tensor): filter cutoff frequency
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns: Returns:
Tensor: Waveform of dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
""" """
dtype = waveform.dtype
device = waveform.device
cutoff_freq = torch.as_tensor(cutoff_freq, dtype=dtype, device=device)
Q = torch.as_tensor(Q, dtype=dtype, device=device)
w0 = 2 * math.pi * cutoff_freq / sample_rate w0 = 2 * math.pi * cutoff_freq / sample_rate
alpha = math.sin(w0) / 2 / Q alpha = torch.sin(w0) / 2 / Q
b0 = (1 - math.cos(w0)) / 2 b0 = (1 - torch.cos(w0)) / 2
b1 = 1 - math.cos(w0) b1 = 1 - torch.cos(w0)
b2 = b0 b2 = b0
a0 = 1 + alpha a0 = 1 + alpha
a1 = -2 * math.cos(w0) a1 = -2 * torch.cos(w0)
a2 = 1 - alpha a2 = 1 - alpha
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
...@@ -1176,9 +1226,9 @@ def treble_biquad( ...@@ -1176,9 +1226,9 @@ def treble_biquad(
Args: Args:
waveform (Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
gain (float): desired gain at the boost (or attenuation) in dB. gain (float or torch.Tensor): desired gain at the boost (or attenuation) in dB.
central_freq (float, optional): central frequency (in Hz). (Default: ``3000``) central_freq (float or torch.Tensor, optional): central frequency (in Hz). (Default: ``3000``)
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``). Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).
Returns: Returns:
Tensor: Waveform of dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
...@@ -1187,13 +1237,19 @@ def treble_biquad( ...@@ -1187,13 +1237,19 @@ def treble_biquad(
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
""" """
dtype = waveform.dtype
device = waveform.device
central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
Q = torch.as_tensor(Q, dtype=dtype, device=device)
gain = torch.as_tensor(gain, dtype=dtype, device=device)
w0 = 2 * math.pi * central_freq / sample_rate w0 = 2 * math.pi * central_freq / sample_rate
alpha = math.sin(w0) / 2 / Q alpha = torch.sin(w0) / 2 / Q
A = math.exp(gain / 40 * math.log(10)) A = torch.exp(gain / 40 * math.log(10))
temp1 = 2 * math.sqrt(A) * alpha temp1 = 2 * torch.sqrt(A) * alpha
temp2 = (A - 1) * math.cos(w0) temp2 = (A - 1) * torch.cos(w0)
temp3 = (A + 1) * math.cos(w0) temp3 = (A + 1) * torch.cos(w0)
b0 = A * ((A + 1) + temp2 + temp1) b0 = A * ((A + 1) + temp2 + temp1)
b1 = -2 * A * ((A - 1) + temp3) b1 = -2 * A * ((A - 1) + temp3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment