Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
7078fcd3
"vscode:/vscode.git/clone" did not exist on "087cbff1512f56a3d2f84477e14416edd62dabac"
Unverified
Commit
7078fcd3
authored
May 19, 2021
by
Caroline Chen
Committed by
GitHub
May 19, 2021
Browse files
Add kaiser window support to resampling (#1509)
parent
b8b732af
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
70 additions
and
33 deletions
+70
-33
test/torchaudio_unittest/compliance_kaldi_test.py
test/torchaudio_unittest/compliance_kaldi_test.py
+28
-15
test/torchaudio_unittest/transforms/transforms_test.py
test/torchaudio_unittest/transforms/transforms_test.py
+4
-2
torchaudio/compliance/kaldi.py
torchaudio/compliance/kaldi.py
+4
-2
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+25
-6
torchaudio/transforms.py
torchaudio/transforms.py
+9
-8
No files found.
test/torchaudio_unittest/compliance_kaldi_test.py
View file @
7078fcd3
...
@@ -7,6 +7,7 @@ import torchaudio.compliance.kaldi as kaldi
...
@@ -7,6 +7,7 @@ import torchaudio.compliance.kaldi as kaldi
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest
import
common_utils
from
.compliance
import
utils
as
compliance_utils
from
.compliance
import
utils
as
compliance_utils
from
parameterized
import
parameterized
def
extract_window
(
window
,
wave
,
f
,
frame_length
,
frame_shift
,
snip_edges
):
def
extract_window
(
window
,
wave
,
f
,
frame_length
,
frame_shift
,
snip_edges
):
...
@@ -182,20 +183,26 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -182,20 +183,26 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
self
.
_compliance_test_helper
(
self
.
test2_filepath
,
'resample'
,
32
,
3
,
get_output_fn
,
atol
=
1e-2
,
rtol
=
1e-5
)
self
.
_compliance_test_helper
(
self
.
test2_filepath
,
'resample'
,
32
,
3
,
get_output_fn
,
atol
=
1e-2
,
rtol
=
1e-5
)
def
test_resample_waveform_upsample_size
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
upsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
*
2
)
def
test_resample_waveform_upsample_size
(
self
,
resampling_method
):
upsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
*
2
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
upsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
)
*
2
)
self
.
assertTrue
(
upsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
)
*
2
)
def
test_resample_waveform_downsample_size
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
)
def
test_resample_waveform_downsample_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
)
//
2
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
)
//
2
)
def
test_resample_waveform_identity_size
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
)
def
test_resample_waveform_identity_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
))
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
))
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
atol
=
1e-1
,
rtol
=
1e-4
):
resampling_method
=
"sinc_interpolation"
,
atol
=
1e-1
,
rtol
=
1e-4
):
# resample the signal and compare it to the ground truth
# resample the signal and compare it to the ground truth
n_to_trim
=
20
n_to_trim
=
20
sample_rate
=
1000
sample_rate
=
1000
...
@@ -211,7 +218,8 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -211,7 +218,8 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
original_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
sample_rate
)
original_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
sample_rate
)
sound
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
original_timestamps
).
unsqueeze
(
0
)
sound
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
original_timestamps
).
unsqueeze
(
0
)
estimate
=
kaldi
.
resample_waveform
(
sound
,
sample_rate
,
new_sample_rate
).
squeeze
()
estimate
=
kaldi
.
resample_waveform
(
sound
,
sample_rate
,
new_sample_rate
,
resampling_method
=
resampling_method
).
squeeze
()
new_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
new_sample_rate
)[:
estimate
.
size
(
0
)]
new_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
new_sample_rate
)[:
estimate
.
size
(
0
)]
ground_truth
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
new_timestamps
)
ground_truth
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
new_timestamps
)
...
@@ -222,15 +230,18 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -222,15 +230,18 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
self
.
assertEqual
(
estimate
,
ground_truth
,
atol
=
atol
,
rtol
=
rtol
)
self
.
assertEqual
(
estimate
,
ground_truth
,
atol
=
atol
,
rtol
=
rtol
)
def
test_resample_waveform_downsample_accuracy
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_downsample_accuracy
(
self
,
resampling_method
):
for
i
in
range
(
1
,
20
):
for
i
in
range
(
1
,
20
):
self
.
_test_resample_waveform_accuracy
(
down_scale_factor
=
i
*
2
)
self
.
_test_resample_waveform_accuracy
(
down_scale_factor
=
i
*
2
,
resampling_method
=
resampling_method
)
def
test_resample_waveform_upsample_accuracy
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_upsample_accuracy
(
self
,
resampling_method
):
for
i
in
range
(
1
,
20
):
for
i
in
range
(
1
,
20
):
self
.
_test_resample_waveform_accuracy
(
up_scale_factor
=
1.0
+
i
/
20.0
)
self
.
_test_resample_waveform_accuracy
(
up_scale_factor
=
1.0
+
i
/
20.0
,
resampling_method
=
resampling_method
)
def
test_resample_waveform_multi_channel
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_multi_channel
(
self
,
resampling_method
):
num_channels
=
3
num_channels
=
3
multi_sound
=
self
.
test1_signal
.
repeat
(
num_channels
,
1
)
# (num_channels, 8000 smp)
multi_sound
=
self
.
test1_signal
.
repeat
(
num_channels
,
1
)
# (num_channels, 8000 smp)
...
@@ -238,11 +249,13 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -238,11 +249,13 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
for
i
in
range
(
num_channels
):
for
i
in
range
(
num_channels
):
multi_sound
[
i
,
:]
*=
(
i
+
1
)
*
1.5
multi_sound
[
i
,
:]
*=
(
i
+
1
)
*
1.5
multi_sound_sampled
=
kaldi
.
resample_waveform
(
multi_sound
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
)
multi_sound_sampled
=
kaldi
.
resample_waveform
(
multi_sound
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
,
resampling_method
=
resampling_method
)
# check that sampling is same whether using separately or in a tensor of size (c, n)
# check that sampling is same whether using separately or in a tensor of size (c, n)
for
i
in
range
(
num_channels
):
for
i
in
range
(
num_channels
):
single_channel
=
self
.
test1_signal
*
(
i
+
1
)
*
1.5
single_channel
=
self
.
test1_signal
*
(
i
+
1
)
*
1.5
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test1_signal_sr
,
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
)
self
.
test1_signal_sr
//
2
,
resampling_method
=
resampling_method
)
self
.
assertEqual
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
self
.
assertEqual
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
test/torchaudio_unittest/transforms/transforms_test.py
View file @
7078fcd3
...
@@ -169,9 +169,11 @@ class Tester(common_utils.TorchaudioTestCase):
...
@@ -169,9 +169,11 @@ class Tester(common_utils.TorchaudioTestCase):
upsample_rate
=
sample_rate
*
2
upsample_rate
=
sample_rate
*
2
downsample_rate
=
sample_rate
//
2
downsample_rate
=
sample_rate
//
2
invalid_
resample
=
torchaudio
.
transforms
.
Resample
(
sample_rate
,
upsample_rate
,
resampling_method
=
'foo'
)
invalid_resampling_method
=
'foo'
self
.
assertRaises
(
ValueError
,
invalid_resample
,
waveform
)
with
self
.
assertRaises
(
ValueError
):
torchaudio
.
transforms
.
Resample
(
sample_rate
,
upsample_rate
,
resampling_method
=
invalid_resampling_method
)
upsample_resample
=
torchaudio
.
transforms
.
Resample
(
upsample_resample
=
torchaudio
.
transforms
.
Resample
(
sample_rate
,
upsample_rate
,
resampling_method
=
'sinc_interpolation'
)
sample_rate
,
upsample_rate
,
resampling_method
=
'sinc_interpolation'
)
...
...
torchaudio/compliance/kaldi.py
View file @
7078fcd3
...
@@ -756,7 +756,8 @@ def resample_waveform(waveform: Tensor,
...
@@ -756,7 +756,8 @@ def resample_waveform(waveform: Tensor,
orig_freq
:
float
,
orig_freq
:
float
,
new_freq
:
float
,
new_freq
:
float
,
lowpass_filter_width
:
int
=
6
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
)
->
Tensor
:
rolloff
:
float
=
0.99
,
resampling_method
:
str
=
"sinc_interpolation"
)
->
Tensor
:
r
"""Resamples the waveform at the new frequency.
r
"""Resamples the waveform at the new frequency.
This is a wrapper around ``torchaudio.functional.resample``.
This is a wrapper around ``torchaudio.functional.resample``.
...
@@ -773,4 +774,5 @@ def resample_waveform(waveform: Tensor,
...
@@ -773,4 +774,5 @@ def resample_waveform(waveform: Tensor,
Returns:
Returns:
Tensor: The waveform at the new frequency
Tensor: The waveform at the new frequency
"""
"""
return
torchaudio
.
functional
.
resample
(
waveform
,
orig_freq
,
new_freq
,
lowpass_filter_width
,
rolloff
)
return
torchaudio
.
functional
.
resample
(
waveform
,
orig_freq
,
new_freq
,
lowpass_filter_width
,
rolloff
,
resampling_method
)
torchaudio/functional/functional.py
View file @
7078fcd3
...
@@ -1303,7 +1303,9 @@ def _get_sinc_resample_kernel(
...
@@ -1303,7 +1303,9 @@ def _get_sinc_resample_kernel(
new_freq
:
float
,
new_freq
:
float
,
gcd
:
int
,
gcd
:
int
,
lowpass_filter_width
:
int
,
lowpass_filter_width
:
int
,
rolloff
:
float
):
rolloff
:
float
,
resampling_method
:
str
,
beta
:
Optional
[
float
]):
if
not
(
int
(
orig_freq
)
==
orig_freq
and
int
(
new_freq
)
==
new_freq
):
if
not
(
int
(
orig_freq
)
==
orig_freq
and
int
(
new_freq
)
==
new_freq
):
warnings
.
warn
(
warnings
.
warn
(
...
@@ -1318,9 +1320,15 @@ def _get_sinc_resample_kernel(
...
@@ -1318,9 +1320,15 @@ def _get_sinc_resample_kernel(
"https://github.com/pytorch/audio/issues/1487."
"https://github.com/pytorch/audio/issues/1487."
)
)
if
resampling_method
not
in
[
'sinc_interpolation'
,
'kaiser_window'
]:
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
resampling_method
))
orig_freq
=
int
(
orig_freq
)
//
gcd
orig_freq
=
int
(
orig_freq
)
//
gcd
new_freq
=
int
(
new_freq
)
//
gcd
new_freq
=
int
(
new_freq
)
//
gcd
if
resampling_method
==
"kaiser_window"
and
beta
is
None
:
beta
=
14.769656459379492
assert
lowpass_filter_width
>
0
assert
lowpass_filter_width
>
0
kernels
=
[]
kernels
=
[]
base_freq
=
min
(
orig_freq
,
new_freq
)
base_freq
=
min
(
orig_freq
,
new_freq
)
...
@@ -1352,15 +1360,20 @@ def _get_sinc_resample_kernel(
...
@@ -1352,15 +1360,20 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
# future work.
idx
=
torch
.
arange
(
-
width
,
width
+
orig_freq
)
idx
=
torch
.
arange
(
-
width
,
width
+
orig_freq
,
dtype
=
torch
.
float64
)
for
i
in
range
(
new_freq
):
for
i
in
range
(
new_freq
):
t
=
(
-
i
/
new_freq
+
idx
/
orig_freq
)
*
base_freq
t
=
(
-
i
/
new_freq
+
idx
/
orig_freq
)
*
base_freq
t
=
t
.
clamp_
(
-
lowpass_filter_width
,
lowpass_filter_width
)
t
=
t
.
clamp_
(
-
lowpass_filter_width
,
lowpass_filter_width
)
t
*=
math
.
pi
# we do not use
torch.hann_
window here as we need to evaluate the window
# we do not use
built in torch
window
s
here as we need to evaluate the window
# at specific positions, not over a regular grid.
# at specific positions, not over a regular grid.
window
=
torch
.
cos
(
t
/
lowpass_filter_width
/
2
)
**
2
if
resampling_method
==
"sinc_interpolation"
:
window
=
torch
.
cos
(
t
*
math
.
pi
/
lowpass_filter_width
/
2
)
**
2
elif
resampling_method
==
"kaiser_window"
:
beta
=
torch
.
tensor
(
beta
,
dtype
=
float
)
window
=
torch
.
i0
(
beta
*
torch
.
sqrt
(
1
-
(
t
/
lowpass_filter_width
)
**
2
))
/
torch
.
i0
(
beta
)
t
*=
math
.
pi
kernel
=
torch
.
where
(
t
==
0
,
torch
.
tensor
(
1.
).
to
(
t
),
torch
.
sin
(
t
)
/
t
)
kernel
=
torch
.
where
(
t
==
0
,
torch
.
tensor
(
1.
).
to
(
t
),
torch
.
sin
(
t
)
/
t
)
kernel
.
mul_
(
window
)
kernel
.
mul_
(
window
)
kernels
.
append
(
kernel
)
kernels
.
append
(
kernel
)
...
@@ -1403,6 +1416,8 @@ def resample(
...
@@ -1403,6 +1416,8 @@ def resample(
new_freq
:
float
,
new_freq
:
float
,
lowpass_filter_width
:
int
=
6
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
,
rolloff
:
float
=
0.99
,
resampling_method
:
str
=
"sinc_interpolation"
,
beta
:
Optional
[
float
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
r
"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
r
"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
...
@@ -1421,6 +1436,9 @@ def resample(
...
@@ -1421,6 +1436,9 @@ def resample(
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
resampling_method (str, optional): The resampling method.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
beta (float, optional): The shape parameter used for kaiser window.
Returns:
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Tensor: The waveform at the new frequency of dimension (..., time).
...
@@ -1433,6 +1451,7 @@ def resample(
...
@@ -1433,6 +1451,7 @@ def resample(
gcd
=
math
.
gcd
(
int
(
orig_freq
),
int
(
new_freq
))
gcd
=
math
.
gcd
(
int
(
orig_freq
),
int
(
new_freq
))
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
gcd
,
lowpass_filter_width
,
rolloff
)
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
gcd
,
lowpass_filter_width
,
rolloff
,
resampling_method
,
beta
)
resampled
=
_apply_sinc_resample_kernel
(
waveform
,
orig_freq
,
new_freq
,
gcd
,
kernel
,
width
)
resampled
=
_apply_sinc_resample_kernel
(
waveform
,
orig_freq
,
new_freq
,
gcd
,
kernel
,
width
)
return
resampled
return
resampled
torchaudio/transforms.py
View file @
7078fcd3
...
@@ -657,11 +657,13 @@ class Resample(torch.nn.Module):
...
@@ -657,11 +657,13 @@ class Resample(torch.nn.Module):
Args:
Args:
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
resampling_method (str, optional): The resampling method.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
beta (float, optional): The shape parameter used for kaiser window.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -669,7 +671,8 @@ class Resample(torch.nn.Module):
...
@@ -669,7 +671,8 @@ class Resample(torch.nn.Module):
new_freq
:
float
=
16000
,
new_freq
:
float
=
16000
,
resampling_method
:
str
=
'sinc_interpolation'
,
resampling_method
:
str
=
'sinc_interpolation'
,
lowpass_filter_width
:
int
=
6
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
)
->
None
:
rolloff
:
float
=
0.99
,
beta
:
Optional
[
float
]
=
None
)
->
None
:
super
(
Resample
,
self
).
__init__
()
super
(
Resample
,
self
).
__init__
()
self
.
orig_freq
=
orig_freq
self
.
orig_freq
=
orig_freq
...
@@ -680,7 +683,8 @@ class Resample(torch.nn.Module):
...
@@ -680,7 +683,8 @@ class Resample(torch.nn.Module):
self
.
rolloff
=
rolloff
self
.
rolloff
=
rolloff
self
.
kernel
,
self
.
width
=
_get_sinc_resample_kernel
(
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
kernel
,
self
.
width
=
_get_sinc_resample_kernel
(
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
lowpass_filter_width
,
self
.
rolloff
)
self
.
lowpass_filter_width
,
self
.
rolloff
,
self
.
resampling_method
,
beta
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
r
"""
...
@@ -690,12 +694,9 @@ class Resample(torch.nn.Module):
...
@@ -690,12 +694,9 @@ class Resample(torch.nn.Module):
Returns:
Returns:
Tensor: Output signal of dimension (..., time).
Tensor: Output signal of dimension (..., time).
"""
"""
if
self
.
resampling_method
==
'sinc_interpolation'
:
return
_apply_sinc_resample_kernel
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
return
_apply_sinc_resample_kernel
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
kernel
,
self
.
width
)
self
.
kernel
,
self
.
width
)
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
self
.
resampling_method
))
class
ComplexNorm
(
torch
.
nn
.
Module
):
class
ComplexNorm
(
torch
.
nn
.
Module
):
r
"""Compute the norm of complex tensor input.
r
"""Compute the norm of complex tensor input.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment