Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
c90c18d7
Unverified
Commit
c90c18d7
authored
Nov 22, 2019
by
Vincent QB
Committed by
GitHub
Nov 22, 2019
Browse files
Move batch from vocoder transform to functional (#350)
* fixing errors in docstring. * move batch to functional.
parent
c74e580f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
26 deletions
+32
-26
torchaudio/functional.py
torchaudio/functional.py
+17
-3
torchaudio/transforms.py
torchaudio/transforms.py
+15
-23
No files found.
torchaudio/functional.py
View file @
c90c18d7
...
...
@@ -469,13 +469,13 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
factor of ``rate``.
Args:
complex_specgrams (torch.Tensor): Dimension of `(
channel
, freq, time, complex=2)`
complex_specgrams (torch.Tensor): Dimension of `(
...
, freq, time, complex=2)`
rate (float): Speed-up factor
phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension
of (freq, 1)
Returns:
complex_specgrams_stretch (torch.Tensor): Dimension of `(
channel
,
complex_specgrams_stretch (torch.Tensor): Dimension of `(
...
,
freq, ceil(time/rate), complex=2)`
Example
...
...
@@ -490,6 +490,10 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
torch.Size([2, 1025, 231, 2])
"""
# pack batch
shape
=
complex_specgrams
.
size
()
complex_specgrams
=
complex_specgrams
.
reshape
([
-
1
]
+
list
(
shape
[
-
3
:]))
time_steps
=
torch
.
arange
(
0
,
complex_specgrams
.
size
(
-
2
),
rate
,
...
...
@@ -527,6 +531,9 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
complex_specgrams_stretch
=
torch
.
stack
([
real_stretch
,
imag_stretch
],
dim
=-
1
)
# unpack batch
complex_specgrams_stretch
=
complex_specgrams_stretch
.
reshape
(
shape
[:
-
3
]
+
complex_specgrams_stretch
.
shape
[
1
:])
return
complex_specgrams_stretch
...
...
@@ -775,6 +782,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
"""
# pack batch
shape
=
specgram
.
size
()
specgram
=
specgram
.
reshape
([
-
1
]
+
list
(
shape
[
-
2
:]))
value
=
torch
.
rand
(
1
)
*
mask_param
min_value
=
torch
.
rand
(
1
)
*
(
specgram
.
size
(
axis
)
-
value
)
...
...
@@ -789,7 +800,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
else
:
raise
ValueError
(
'Only Frequency and Time masking are supported'
)
return
specgram
# unpack batch
specgram
=
specgram
.
reshape
(
shape
[:
-
2
]
+
specgram
.
shape
[
-
2
:])
return
specgram
.
reshape
(
shape
[:
-
2
]
+
specgram
.
shape
[
-
2
:])
def
compute_deltas
(
specgram
,
win_length
=
5
,
mode
=
"replicate"
):
...
...
torchaudio/transforms.py
View file @
c90c18d7
...
...
@@ -380,9 +380,9 @@ class ComplexNorm(torch.nn.Module):
def
forward
(
self
,
complex_tensor
):
r
"""
Args:
complex_tensor (Tensor): Tensor shape of `(
*
, complex=2)`
complex_tensor (Tensor): Tensor shape of `(
...
, complex=2)`
Returns:
Tensor: norm of the input tensor, shape of `(
*
, )`
Tensor: norm of the input tensor, shape of `(
...
, )`
"""
return
F
.
complex_norm
(
complex_tensor
,
self
.
power
)
...
...
@@ -438,14 +438,14 @@ class TimeStretch(torch.jit.ScriptModule):
# type: (Tensor, Optional[float]) -> Tensor
r
"""
Args:
complex_specgrams (Tensor): complex spectrogram (
*, channel
, freq, time, complex=2)
complex_specgrams (Tensor): complex spectrogram (
...
, freq, time, complex=2)
overriding_rate (float or None): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``
Returns:
(Tensor): Stretched complex spectrogram of dimension (
*, channel
, freq, ceil(time/rate), complex=2)
(Tensor): Stretched complex spectrogram of dimension (
...
, freq, ceil(time/rate), complex=2)
"""
assert
complex_specgrams
.
size
(
-
1
)
==
2
,
"complex_specgrams should be a complex tensor, shape (
*
, complex=2)"
assert
complex_specgrams
.
size
(
-
1
)
==
2
,
"complex_specgrams should be a complex tensor, shape (
...
, complex=2)"
if
overriding_rate
is
None
:
rate
=
self
.
fixed_rate
...
...
@@ -458,16 +458,12 @@ class TimeStretch(torch.jit.ScriptModule):
if
rate
==
1.0
:
return
complex_specgrams
shape
=
complex_specgrams
.
size
()
complex_specgrams
=
complex_specgrams
.
reshape
([
-
1
]
+
list
(
shape
[
-
3
:]))
complex_specgrams
=
F
.
phase_vocoder
(
complex_specgrams
,
rate
,
self
.
phase_advance
)
return
complex_specgrams
.
reshape
(
shape
[:
-
3
]
+
complex_specgrams
.
shape
[
-
3
:])
return
F
.
phase_vocoder
(
complex_specgrams
,
rate
,
self
.
phase_advance
)
class
_AxisMasking
(
torch
.
nn
.
Module
):
r
"""
Apply masking to a spectrogram.
r
"""
Apply masking to a spectrogram.
Args:
mask_param (int): Maximum possible length of the mask
axis: What dimension the mask is applied on
...
...
@@ -486,26 +482,22 @@ class _AxisMasking(torch.nn.Module):
# type: (Tensor, float) -> Tensor
r
"""
Args:
specgram (torch.Tensor): Tensor of dimension (
*, channel
, freq, time)
specgram (torch.Tensor): Tensor of dimension (
...
, freq, time)
Returns:
torch.Tensor: Masked spectrogram of dimensions (
*, channel
, freq, time)
torch.Tensor: Masked spectrogram of dimensions (
...
, freq, time)
"""
# if iid_masks flag marked and specgram has a batch dimension
if
self
.
iid_masks
and
specgram
.
dim
()
==
4
:
return
F
.
mask_along_axis_iid
(
specgram
,
self
.
mask_param
,
mask_value
,
self
.
axis
+
1
)
else
:
shape
=
specgram
.
size
()
specgram
=
specgram
.
reshape
([
-
1
]
+
list
(
shape
[
-
2
:]))
specgram
=
F
.
mask_along_axis
(
specgram
,
self
.
mask_param
,
mask_value
,
self
.
axis
)
return
specgram
.
reshape
(
shape
[:
-
2
]
+
specgram
.
shape
[
-
2
:])
return
F
.
mask_along_axis
(
specgram
,
self
.
mask_param
,
mask_value
,
self
.
axis
)
class
FrequencyMasking
(
_AxisMasking
):
r
"""
Apply masking to a spectrogram in the frequency domain.
r
"""
Apply masking to a spectrogram in the frequency domain.
Args:
freq_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, freq_mask_param).
...
...
@@ -518,8 +510,8 @@ class FrequencyMasking(_AxisMasking):
class
TimeMasking
(
_AxisMasking
):
r
"""
Apply masking to a spectrogram in the time domain.
r
"""
Apply masking to a spectrogram in the time domain.
Args:
time_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, time_mask_param).
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment