Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
7078fcd3
Unverified
Commit
7078fcd3
authored
May 19, 2021
by
Caroline Chen
Committed by
GitHub
May 19, 2021
Browse files
Add kaiser window support to resampling (#1509)
parent
b8b732af
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
70 additions
and
33 deletions
+70
-33
test/torchaudio_unittest/compliance_kaldi_test.py
test/torchaudio_unittest/compliance_kaldi_test.py
+28
-15
test/torchaudio_unittest/transforms/transforms_test.py
test/torchaudio_unittest/transforms/transforms_test.py
+4
-2
torchaudio/compliance/kaldi.py
torchaudio/compliance/kaldi.py
+4
-2
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+25
-6
torchaudio/transforms.py
torchaudio/transforms.py
+9
-8
No files found.
test/torchaudio_unittest/compliance_kaldi_test.py
View file @
7078fcd3
...
@@ -7,6 +7,7 @@ import torchaudio.compliance.kaldi as kaldi
...
@@ -7,6 +7,7 @@ import torchaudio.compliance.kaldi as kaldi
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest
import
common_utils
from
.compliance
import
utils
as
compliance_utils
from
.compliance
import
utils
as
compliance_utils
from
parameterized
import
parameterized
def
extract_window
(
window
,
wave
,
f
,
frame_length
,
frame_shift
,
snip_edges
):
def
extract_window
(
window
,
wave
,
f
,
frame_length
,
frame_shift
,
snip_edges
):
...
@@ -182,20 +183,26 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -182,20 +183,26 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
self
.
_compliance_test_helper
(
self
.
test2_filepath
,
'resample'
,
32
,
3
,
get_output_fn
,
atol
=
1e-2
,
rtol
=
1e-5
)
self
.
_compliance_test_helper
(
self
.
test2_filepath
,
'resample'
,
32
,
3
,
get_output_fn
,
atol
=
1e-2
,
rtol
=
1e-5
)
def
test_resample_waveform_upsample_size
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
upsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
*
2
)
def
test_resample_waveform_upsample_size
(
self
,
resampling_method
):
upsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
*
2
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
upsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
)
*
2
)
self
.
assertTrue
(
upsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
)
*
2
)
def
test_resample_waveform_downsample_size
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
)
def
test_resample_waveform_downsample_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
)
//
2
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
)
//
2
)
def
test_resample_waveform_identity_size
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
)
def
test_resample_waveform_identity_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test1_signal
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
))
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test1_signal
.
size
(
-
1
))
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
atol
=
1e-1
,
rtol
=
1e-4
):
resampling_method
=
"sinc_interpolation"
,
atol
=
1e-1
,
rtol
=
1e-4
):
# resample the signal and compare it to the ground truth
# resample the signal and compare it to the ground truth
n_to_trim
=
20
n_to_trim
=
20
sample_rate
=
1000
sample_rate
=
1000
...
@@ -211,7 +218,8 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -211,7 +218,8 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
original_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
sample_rate
)
original_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
sample_rate
)
sound
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
original_timestamps
).
unsqueeze
(
0
)
sound
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
original_timestamps
).
unsqueeze
(
0
)
estimate
=
kaldi
.
resample_waveform
(
sound
,
sample_rate
,
new_sample_rate
).
squeeze
()
estimate
=
kaldi
.
resample_waveform
(
sound
,
sample_rate
,
new_sample_rate
,
resampling_method
=
resampling_method
).
squeeze
()
new_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
new_sample_rate
)[:
estimate
.
size
(
0
)]
new_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
new_sample_rate
)[:
estimate
.
size
(
0
)]
ground_truth
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
new_timestamps
)
ground_truth
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
new_timestamps
)
...
@@ -222,15 +230,18 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -222,15 +230,18 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
self
.
assertEqual
(
estimate
,
ground_truth
,
atol
=
atol
,
rtol
=
rtol
)
self
.
assertEqual
(
estimate
,
ground_truth
,
atol
=
atol
,
rtol
=
rtol
)
def
test_resample_waveform_downsample_accuracy
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_downsample_accuracy
(
self
,
resampling_method
):
for
i
in
range
(
1
,
20
):
for
i
in
range
(
1
,
20
):
self
.
_test_resample_waveform_accuracy
(
down_scale_factor
=
i
*
2
)
self
.
_test_resample_waveform_accuracy
(
down_scale_factor
=
i
*
2
,
resampling_method
=
resampling_method
)
def
test_resample_waveform_upsample_accuracy
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_upsample_accuracy
(
self
,
resampling_method
):
for
i
in
range
(
1
,
20
):
for
i
in
range
(
1
,
20
):
self
.
_test_resample_waveform_accuracy
(
up_scale_factor
=
1.0
+
i
/
20.0
)
self
.
_test_resample_waveform_accuracy
(
up_scale_factor
=
1.0
+
i
/
20.0
,
resampling_method
=
resampling_method
)
def
test_resample_waveform_multi_channel
(
self
):
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_multi_channel
(
self
,
resampling_method
):
num_channels
=
3
num_channels
=
3
multi_sound
=
self
.
test1_signal
.
repeat
(
num_channels
,
1
)
# (num_channels, 8000 smp)
multi_sound
=
self
.
test1_signal
.
repeat
(
num_channels
,
1
)
# (num_channels, 8000 smp)
...
@@ -238,11 +249,13 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -238,11 +249,13 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
for
i
in
range
(
num_channels
):
for
i
in
range
(
num_channels
):
multi_sound
[
i
,
:]
*=
(
i
+
1
)
*
1.5
multi_sound
[
i
,
:]
*=
(
i
+
1
)
*
1.5
multi_sound_sampled
=
kaldi
.
resample_waveform
(
multi_sound
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
)
multi_sound_sampled
=
kaldi
.
resample_waveform
(
multi_sound
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
,
resampling_method
=
resampling_method
)
# check that sampling is same whether using separately or in a tensor of size (c, n)
# check that sampling is same whether using separately or in a tensor of size (c, n)
for
i
in
range
(
num_channels
):
for
i
in
range
(
num_channels
):
single_channel
=
self
.
test1_signal
*
(
i
+
1
)
*
1.5
single_channel
=
self
.
test1_signal
*
(
i
+
1
)
*
1.5
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test1_signal_sr
,
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
)
self
.
test1_signal_sr
//
2
,
resampling_method
=
resampling_method
)
self
.
assertEqual
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
self
.
assertEqual
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
test/torchaudio_unittest/transforms/transforms_test.py
View file @
7078fcd3
...
@@ -169,9 +169,11 @@ class Tester(common_utils.TorchaudioTestCase):
...
@@ -169,9 +169,11 @@ class Tester(common_utils.TorchaudioTestCase):
upsample_rate
=
sample_rate
*
2
upsample_rate
=
sample_rate
*
2
downsample_rate
=
sample_rate
//
2
downsample_rate
=
sample_rate
//
2
invalid_
resample
=
torchaudio
.
transforms
.
Resample
(
sample_rate
,
upsample_rate
,
resampling_method
=
'foo'
)
invalid_resampling_method
=
'foo'
self
.
assertRaises
(
ValueError
,
invalid_resample
,
waveform
)
with
self
.
assertRaises
(
ValueError
):
torchaudio
.
transforms
.
Resample
(
sample_rate
,
upsample_rate
,
resampling_method
=
invalid_resampling_method
)
upsample_resample
=
torchaudio
.
transforms
.
Resample
(
upsample_resample
=
torchaudio
.
transforms
.
Resample
(
sample_rate
,
upsample_rate
,
resampling_method
=
'sinc_interpolation'
)
sample_rate
,
upsample_rate
,
resampling_method
=
'sinc_interpolation'
)
...
...
torchaudio/compliance/kaldi.py
View file @
7078fcd3
...
@@ -756,7 +756,8 @@ def resample_waveform(waveform: Tensor,
...
@@ -756,7 +756,8 @@ def resample_waveform(waveform: Tensor,
orig_freq
:
float
,
orig_freq
:
float
,
new_freq
:
float
,
new_freq
:
float
,
lowpass_filter_width
:
int
=
6
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
)
->
Tensor
:
rolloff
:
float
=
0.99
,
resampling_method
:
str
=
"sinc_interpolation"
)
->
Tensor
:
r
"""Resamples the waveform at the new frequency.
r
"""Resamples the waveform at the new frequency.
This is a wrapper around ``torchaudio.functional.resample``.
This is a wrapper around ``torchaudio.functional.resample``.
...
@@ -773,4 +774,5 @@ def resample_waveform(waveform: Tensor,
...
@@ -773,4 +774,5 @@ def resample_waveform(waveform: Tensor,
Returns:
Returns:
Tensor: The waveform at the new frequency
Tensor: The waveform at the new frequency
"""
"""
return
torchaudio
.
functional
.
resample
(
waveform
,
orig_freq
,
new_freq
,
lowpass_filter_width
,
rolloff
)
return
torchaudio
.
functional
.
resample
(
waveform
,
orig_freq
,
new_freq
,
lowpass_filter_width
,
rolloff
,
resampling_method
)
torchaudio/functional/functional.py
View file @
7078fcd3
...
@@ -1303,7 +1303,9 @@ def _get_sinc_resample_kernel(
...
@@ -1303,7 +1303,9 @@ def _get_sinc_resample_kernel(
new_freq
:
float
,
new_freq
:
float
,
gcd
:
int
,
gcd
:
int
,
lowpass_filter_width
:
int
,
lowpass_filter_width
:
int
,
rolloff
:
float
):
rolloff
:
float
,
resampling_method
:
str
,
beta
:
Optional
[
float
]):
if
not
(
int
(
orig_freq
)
==
orig_freq
and
int
(
new_freq
)
==
new_freq
):
if
not
(
int
(
orig_freq
)
==
orig_freq
and
int
(
new_freq
)
==
new_freq
):
warnings
.
warn
(
warnings
.
warn
(
...
@@ -1318,9 +1320,15 @@ def _get_sinc_resample_kernel(
...
@@ -1318,9 +1320,15 @@ def _get_sinc_resample_kernel(
"https://github.com/pytorch/audio/issues/1487."
"https://github.com/pytorch/audio/issues/1487."
)
)
if
resampling_method
not
in
[
'sinc_interpolation'
,
'kaiser_window'
]:
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
resampling_method
))
orig_freq
=
int
(
orig_freq
)
//
gcd
orig_freq
=
int
(
orig_freq
)
//
gcd
new_freq
=
int
(
new_freq
)
//
gcd
new_freq
=
int
(
new_freq
)
//
gcd
if
resampling_method
==
"kaiser_window"
and
beta
is
None
:
beta
=
14.769656459379492
assert
lowpass_filter_width
>
0
assert
lowpass_filter_width
>
0
kernels
=
[]
kernels
=
[]
base_freq
=
min
(
orig_freq
,
new_freq
)
base_freq
=
min
(
orig_freq
,
new_freq
)
...
@@ -1352,15 +1360,20 @@ def _get_sinc_resample_kernel(
...
@@ -1352,15 +1360,20 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
# future work.
idx
=
torch
.
arange
(
-
width
,
width
+
orig_freq
)
idx
=
torch
.
arange
(
-
width
,
width
+
orig_freq
,
dtype
=
torch
.
float64
)
for
i
in
range
(
new_freq
):
for
i
in
range
(
new_freq
):
t
=
(
-
i
/
new_freq
+
idx
/
orig_freq
)
*
base_freq
t
=
(
-
i
/
new_freq
+
idx
/
orig_freq
)
*
base_freq
t
=
t
.
clamp_
(
-
lowpass_filter_width
,
lowpass_filter_width
)
t
=
t
.
clamp_
(
-
lowpass_filter_width
,
lowpass_filter_width
)
t
*=
math
.
pi
# we do not use
torch.hann_
window here as we need to evaluate the window
# we do not use
built in torch
window
s
here as we need to evaluate the window
# at specific positions, not over a regular grid.
# at specific positions, not over a regular grid.
window
=
torch
.
cos
(
t
/
lowpass_filter_width
/
2
)
**
2
if
resampling_method
==
"sinc_interpolation"
:
window
=
torch
.
cos
(
t
*
math
.
pi
/
lowpass_filter_width
/
2
)
**
2
elif
resampling_method
==
"kaiser_window"
:
beta
=
torch
.
tensor
(
beta
,
dtype
=
float
)
window
=
torch
.
i0
(
beta
*
torch
.
sqrt
(
1
-
(
t
/
lowpass_filter_width
)
**
2
))
/
torch
.
i0
(
beta
)
t
*=
math
.
pi
kernel
=
torch
.
where
(
t
==
0
,
torch
.
tensor
(
1.
).
to
(
t
),
torch
.
sin
(
t
)
/
t
)
kernel
=
torch
.
where
(
t
==
0
,
torch
.
tensor
(
1.
).
to
(
t
),
torch
.
sin
(
t
)
/
t
)
kernel
.
mul_
(
window
)
kernel
.
mul_
(
window
)
kernels
.
append
(
kernel
)
kernels
.
append
(
kernel
)
...
@@ -1403,6 +1416,8 @@ def resample(
...
@@ -1403,6 +1416,8 @@ def resample(
new_freq
:
float
,
new_freq
:
float
,
lowpass_filter_width
:
int
=
6
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
,
rolloff
:
float
=
0.99
,
resampling_method
:
str
=
"sinc_interpolation"
,
beta
:
Optional
[
float
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
r
"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
r
"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
...
@@ -1421,6 +1436,9 @@ def resample(
...
@@ -1421,6 +1436,9 @@ def resample(
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
resampling_method (str, optional): The resampling method.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
beta (float, optional): The shape parameter used for kaiser window.
Returns:
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Tensor: The waveform at the new frequency of dimension (..., time).
...
@@ -1433,6 +1451,7 @@ def resample(
...
@@ -1433,6 +1451,7 @@ def resample(
gcd
=
math
.
gcd
(
int
(
orig_freq
),
int
(
new_freq
))
gcd
=
math
.
gcd
(
int
(
orig_freq
),
int
(
new_freq
))
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
gcd
,
lowpass_filter_width
,
rolloff
)
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
gcd
,
lowpass_filter_width
,
rolloff
,
resampling_method
,
beta
)
resampled
=
_apply_sinc_resample_kernel
(
waveform
,
orig_freq
,
new_freq
,
gcd
,
kernel
,
width
)
resampled
=
_apply_sinc_resample_kernel
(
waveform
,
orig_freq
,
new_freq
,
gcd
,
kernel
,
width
)
return
resampled
return
resampled
torchaudio/transforms.py
View file @
7078fcd3
...
@@ -657,11 +657,13 @@ class Resample(torch.nn.Module):
...
@@ -657,11 +657,13 @@ class Resample(torch.nn.Module):
Args:
Args:
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
resampling_method (str, optional): The resampling method.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
beta (float, optional): The shape parameter used for kaiser window.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -669,7 +671,8 @@ class Resample(torch.nn.Module):
...
@@ -669,7 +671,8 @@ class Resample(torch.nn.Module):
new_freq
:
float
=
16000
,
new_freq
:
float
=
16000
,
resampling_method
:
str
=
'sinc_interpolation'
,
resampling_method
:
str
=
'sinc_interpolation'
,
lowpass_filter_width
:
int
=
6
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
)
->
None
:
rolloff
:
float
=
0.99
,
beta
:
Optional
[
float
]
=
None
)
->
None
:
super
(
Resample
,
self
).
__init__
()
super
(
Resample
,
self
).
__init__
()
self
.
orig_freq
=
orig_freq
self
.
orig_freq
=
orig_freq
...
@@ -680,7 +683,8 @@ class Resample(torch.nn.Module):
...
@@ -680,7 +683,8 @@ class Resample(torch.nn.Module):
self
.
rolloff
=
rolloff
self
.
rolloff
=
rolloff
self
.
kernel
,
self
.
width
=
_get_sinc_resample_kernel
(
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
kernel
,
self
.
width
=
_get_sinc_resample_kernel
(
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
lowpass_filter_width
,
self
.
rolloff
)
self
.
lowpass_filter_width
,
self
.
rolloff
,
self
.
resampling_method
,
beta
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
r
"""
...
@@ -690,12 +694,9 @@ class Resample(torch.nn.Module):
...
@@ -690,12 +694,9 @@ class Resample(torch.nn.Module):
Returns:
Returns:
Tensor: Output signal of dimension (..., time).
Tensor: Output signal of dimension (..., time).
"""
"""
if
self
.
resampling_method
==
'sinc_interpolation'
:
return
_apply_sinc_resample_kernel
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
return
_apply_sinc_resample_kernel
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
kernel
,
self
.
width
)
self
.
kernel
,
self
.
width
)
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
self
.
resampling_method
))
class
ComplexNorm
(
torch
.
nn
.
Module
):
class
ComplexNorm
(
torch
.
nn
.
Module
):
r
"""Compute the norm of complex tensor input.
r
"""Compute the norm of complex tensor input.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment