Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
f5dbb002
Unverified
Commit
f5dbb002
authored
Jul 16, 2021
by
nateanl
Committed by
GitHub
Jul 16, 2021
Browse files
Add PitchShift to functional and transform (#1629)
parent
0ea6d10d
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
173 additions
and
0 deletions
+173
-0
docs/source/functional.rst
docs/source/functional.rst
+5
-0
docs/source/transforms.rst
docs/source/transforms.rst
+7
-0
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+10
-0
test/torchaudio_unittest/transforms/batch_consistency_test.py
.../torchaudio_unittest/transforms/batch_consistency_test.py
+12
-0
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
...audio_unittest/transforms/torchscript_consistency_impl.py
+9
-0
torchaudio/functional/__init__.py
torchaudio/functional/__init__.py
+2
-0
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+74
-0
torchaudio/transforms.py
torchaudio/transforms.py
+54
-0
No files found.
docs/source/functional.rst
View file @
f5dbb002
...
@@ -211,6 +211,11 @@ vad
...
@@ -211,6 +211,11 @@ vad
.. autofunction:: phase_vocoder
.. autofunction:: phase_vocoder
:hidden:`pitch_shift`
-----------------------
.. autofunction:: pitch_shift
:hidden:`compute_deltas`
:hidden:`compute_deltas`
------------------------
------------------------
...
...
docs/source/transforms.rst
View file @
f5dbb002
...
@@ -101,6 +101,13 @@ Transforms are common audio transforms. They can be chained together using :clas
...
@@ -101,6 +101,13 @@ Transforms are common audio transforms. They can be chained together using :clas
.. automethod:: forward
.. automethod:: forward
:hidden:`PitchShift`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: PitchShift
.. automethod:: forward
:hidden:`Fade`
:hidden:`Fade`
~~~~~~~~~~~~~~
~~~~~~~~~~~~~~
...
...
test/torchaudio_unittest/functional/functional_impl.py
View file @
f5dbb002
...
@@ -422,6 +422,16 @@ class Functional(TestBaseMixin):
...
@@ -422,6 +422,16 @@ class Functional(TestBaseMixin):
assert
F
.
edit_distance
(
seq1
,
seq2
)
==
distance
assert
F
.
edit_distance
(
seq1
,
seq2
)
==
distance
assert
F
.
edit_distance
(
seq2
,
seq1
)
==
distance
assert
F
.
edit_distance
(
seq2
,
seq1
)
==
distance
@nested_params(
    [-4, -2, 0, 2, 4],
)
def test_pitch_shift_shape(self, n_steps):
    """pitch_shift must return a waveform with the same shape as its input,
    for negative, zero, and positive step counts."""
    torch.random.manual_seed(42)
    rate = 16000
    signal = torch.rand(2, 44100, dtype=self.dtype, device=self.device)
    shifted = F.pitch_shift(signal, rate, n_steps)
    assert shifted.size() == signal.size()
class
FunctionalCPUOnly
(
TestBaseMixin
):
class
FunctionalCPUOnly
(
TestBaseMixin
):
def
test_create_fb_matrix_no_warning_high_n_freq
(
self
):
def
test_create_fb_matrix_no_warning_high_n_freq
(
self
):
...
...
test/torchaudio_unittest/transforms/batch_consistency_test.py
View file @
f5dbb002
...
@@ -187,3 +187,15 @@ class TestTransforms(common_utils.TorchaudioTestCase):
...
@@ -187,3 +187,15 @@ class TestTransforms(common_utils.TorchaudioTestCase):
# Batch then transform
# Batch then transform
computed
=
torchaudio
.
transforms
.
SpectralCentroid
(
sample_rate
)(
waveform
.
repeat
(
3
,
1
,
1
))
computed
=
torchaudio
.
transforms
.
SpectralCentroid
(
sample_rate
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def test_batch_pitch_shift(self):
    """Applying PitchShift to a batch must equal applying it to one example
    and tiling the result."""
    rate = 44100
    steps = 4
    noise = common_utils.get_whitenoise(sample_rate=rate)

    def shift(x):
        # Fresh transform per call, mirroring two independent applications.
        return torchaudio.transforms.PitchShift(rate, steps)(x)

    # Transform a single example, then tile it into a batch of three.
    expected = shift(noise).repeat(3, 1, 1)

    # Tile first, then transform the whole batch at once.
    computed = shift(noise.repeat(3, 1, 1))
    self.assertEqual(computed, expected)
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
View file @
f5dbb002
...
@@ -135,3 +135,12 @@ class Transforms(TempDirMixin, TestBaseMixin):
...
@@ -135,3 +135,12 @@ class Transforms(TempDirMixin, TestBaseMixin):
tensor
,
tensor
,
test_pseudo_complex
test_pseudo_complex
)
)
def test_PitchShift(self):
    """TorchScript-compiled PitchShift must produce the same output as the
    eager transform."""
    rate, steps = 8000, 4
    noise = common_utils.get_whitenoise(sample_rate=rate)
    transform = T.PitchShift(sample_rate=rate, n_steps=steps)
    self._assert_consistency(transform, noise)
torchaudio/functional/__init__.py
View file @
f5dbb002
...
@@ -21,6 +21,7 @@ from .functional import (
...
@@ -21,6 +21,7 @@ from .functional import (
apply_codec
,
apply_codec
,
resample
,
resample
,
edit_distance
,
edit_distance
,
pitch_shift
,
)
)
from
.filtering
import
(
from
.filtering
import
(
allpass_biquad
,
allpass_biquad
,
...
@@ -90,4 +91,5 @@ __all__ = [
...
@@ -90,4 +91,5 @@ __all__ = [
'apply_codec'
,
'apply_codec'
,
'resample'
,
'resample'
,
'edit_distance'
,
'edit_distance'
,
'pitch_shift'
,
]
]
torchaudio/functional/functional.py
View file @
f5dbb002
...
@@ -36,6 +36,7 @@ __all__ = [
...
@@ -36,6 +36,7 @@ __all__ = [
"apply_codec"
,
"apply_codec"
,
"resample"
,
"resample"
,
"edit_distance"
,
"edit_distance"
,
"pitch_shift"
,
]
]
...
@@ -1488,3 +1489,76 @@ def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
...
@@ -1488,3 +1489,76 @@ def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
dnew
,
dold
=
dold
,
dnew
dnew
,
dold
=
dold
,
dnew
return
int
(
dold
[
-
1
])
return
int
(
dold
[
-
1
])
def pitch_shift(
        waveform: Tensor,
        sample_rate: int,
        n_steps: int,
        bins_per_octave: int = 12,
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        window: Optional[Tensor] = None,
) -> Tensor:
    """
    Shift the pitch of a waveform by ``n_steps`` steps, where each step is
    ``1 / bins_per_octave`` of an octave.

    The pitch shift is performed by time-stretching the signal in the STFT
    domain with a phase vocoder (by a factor of ``2 ** (-n_steps /
    bins_per_octave)``) and then resampling the stretched signal back to the
    original length, which raises or lowers the pitch.

    Args:
        waveform (Tensor): The input waveform of shape `(..., time)`.
        sample_rate (int): Sample rate of ``waveform``.
        n_steps (int): The number of steps to shift ``waveform``.
            Positive values shift the pitch up, negative values down.
        bins_per_octave (int, optional): The number of steps per octave (Default: ``12``).
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``).
        win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``).
        hop_length (int or None, optional): Length of hop between STFT windows. If None, then
            ``n_fft // 4`` is used (Default: ``None``).
        window (Tensor or None, optional): Window tensor that is applied/multiplied to each frame/window.
            If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``).

    Returns:
        Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
    """
    # NOTE: hop_length defaults to n_fft // 4 (not win_length // 4), so it is
    # computed before win_length defaults are resolved.
    if hop_length is None:
        hop_length = n_fft // 4
    if win_length is None:
        win_length = n_fft
    if window is None:
        # Allocate the window on the waveform's device so stft does not have
        # to move tensors across devices.
        window = torch.hann_window(window_length=win_length, device=waveform.device)

    # pack batch: flatten all leading dimensions so stft sees (batch, time).
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    ori_len = shape[-1]
    # Time-stretch factor: < 1 for a positive shift (stretch longer), > 1 for
    # a negative shift.
    rate = 2.0 ** (-float(n_steps) / bins_per_octave)
    spec_f = torch.stft(input=waveform,
                        n_fft=n_fft,
                        hop_length=hop_length,
                        win_length=win_length,
                        window=window,
                        center=True,
                        pad_mode='reflect',
                        normalized=False,
                        onesided=True,
                        return_complex=True)
    # Expected phase advance per hop for each of the n_fft // 2 + 1 frequency
    # bins (0 .. pi * hop_length); trailing None adds the frame axis.
    phase_advance = torch.linspace(0,
                                   math.pi * hop_length,
                                   spec_f.shape[-2],
                                   device=spec_f.device)[..., None]
    spec_stretch = phase_vocoder(spec_f, rate, phase_advance)
    len_stretch = int(round(ori_len / rate))
    waveform_stretch = torch.istft(spec_stretch,
                                   n_fft=n_fft,
                                   hop_length=hop_length,
                                   win_length=win_length,
                                   window=window,
                                   length=len_stretch)
    # Resample the stretched signal as if it were sampled at sample_rate / rate;
    # this restores (approximately) the original duration while shifting pitch.
    waveform_shift = resample(waveform_stretch, sample_rate / rate, float(sample_rate))
    # Resampling may be off by a few samples: trim or zero-pad to exactly ori_len.
    shift_len = waveform_shift.size()[-1]
    if shift_len > ori_len:
        waveform_shift = waveform_shift[..., :ori_len]
    else:
        waveform_shift = torch.nn.functional.pad(waveform_shift, [0, ori_len - shift_len])

    # unpack batch: restore the original leading dimensions.
    waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
    return waveform_shift
torchaudio/transforms.py
View file @
f5dbb002
...
@@ -34,6 +34,7 @@ __all__ = [
...
@@ -34,6 +34,7 @@ __all__ = [
'SpectralCentroid'
,
'SpectralCentroid'
,
'Vol'
,
'Vol'
,
'ComputeDeltas'
,
'ComputeDeltas'
,
'PitchShift'
,
]
]
...
@@ -1210,3 +1211,56 @@ class SpectralCentroid(torch.nn.Module):
...
@@ -1210,3 +1211,56 @@ class SpectralCentroid(torch.nn.Module):
return
F
.
spectral_centroid
(
waveform
,
self
.
sample_rate
,
self
.
pad
,
self
.
window
,
self
.
n_fft
,
self
.
hop_length
,
return
F
.
spectral_centroid
(
waveform
,
self
.
sample_rate
,
self
.
pad
,
self
.
window
,
self
.
n_fft
,
self
.
hop_length
,
self
.
win_length
)
self
.
win_length
)
class PitchShift(torch.nn.Module):
    r"""Shift the pitch of a waveform by ``n_steps`` steps, where each step is
    ``1 / bins_per_octave`` of an octave.

    The input waveform is passed to :meth:`forward`; the constructor only
    configures the transform and precomputes the STFT window.

    Args:
        sample_rate (int): Sample rate of the input waveform.
        n_steps (int): The number of steps to shift the waveform.
        bins_per_octave (int, optional): The number of steps per octave (Default : ``12``).
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``).
        win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``).
        hop_length (int or None, optional): Length of hop between STFT windows. If None, then ``win_length // 4``
            is used (Default: ``None``).
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor that is
            applied/multiplied to each frame/window (Default: ``torch.hann_window``).
        wkwargs (dict or None, optional): Arguments for the window function (Default: ``None``).

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> waveform_shift = transforms.PitchShift(sample_rate, 4)(waveform)  # (channel, time)
    """
    # Declared for TorchScript: these attributes are treated as constants.
    __constants__ = ['sample_rate', 'n_steps', 'bins_per_octave', 'n_fft', 'win_length', 'hop_length']

    def __init__(self,
                 sample_rate: int,
                 n_steps: int,
                 bins_per_octave: int = 12,
                 n_fft: int = 512,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 wkwargs: Optional[dict] = None) -> None:
        super(PitchShift, self).__init__()
        self.n_steps = n_steps
        self.bins_per_octave = bins_per_octave
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        # Resolve the None defaults eagerly so the stored values are concrete
        # ints (required for __constants__ / TorchScript).
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 4
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        # Registered as a buffer so the window moves with .to()/.cuda() and is
        # saved in state_dict, but is not a trainable parameter.
        self.register_buffer('window', window)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: The pitch-shifted audio of shape `(..., time)`.
        """
        return F.pitch_shift(waveform, self.sample_rate, self.n_steps, self.bins_per_octave,
                             self.n_fft, self.win_length, self.hop_length, self.window)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment