Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
954d5121
Unverified
Commit
954d5121
authored
Apr 22, 2020
by
moto
Committed by
GitHub
Apr 22, 2020
Browse files
Make sliding_window_cmn batch-aware (#570)
parent
38287a75
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
13 deletions
+28
-13
test/test_batch_consistency.py
test/test_batch_consistency.py
+7
-0
torchaudio/functional.py
torchaudio/functional.py
+21
-13
No files found.
test/test_batch_consistency.py
View file @
954d5121
...
...
@@ -72,6 +72,13 @@ class TestFunctional(unittest.TestCase):
waveform
=
torch
.
rand
(
2
,
100
)
-
0.5
_test_batch
(
F
.
dcshift
,
waveform
,
shift
=
0.5
,
limiter_gain
=
0.05
)
def test_sliding_window_cmn(self):
    # Run the `_test_batch` helper on `F.sliding_window_cmn` for every
    # combination of the `center` and `norm_vars` flags, in the order
    # (True, True), (True, False), (False, True), (False, False).
    waveform = torch.randn(2, 1024) - 0.5
    for center in (True, False):
        for norm_vars in (True, False):
            _test_batch(
                F.sliding_window_cmn,
                waveform,
                center=center,
                norm_vars=norm_vars,
            )
class
TestTransforms
(
unittest
.
TestCase
):
"""Test suite for classes defined in `transforms` module"""
...
...
torchaudio/functional.py
View file @
954d5121
...
...
@@ -1713,14 +1713,18 @@ def sliding_window_cmn(
Returns:
Tensor: Tensor of freq of dimension (..., frame)
"""
input_shape
=
waveform
.
shape
num_frames
,
num_feats
=
input_shape
[
-
2
:]
waveform
=
waveform
.
view
(
-
1
,
num_frames
,
num_feats
)
num_channels
=
waveform
.
shape
[
0
]
dtype
=
waveform
.
dtype
device
=
waveform
.
device
last_window_start
=
last_window_end
=
-
1
num_frames
,
num_feats
=
waveform
.
shape
cur_sum
=
torch
.
zeros
(
num_feats
,
dtype
=
dtype
,
device
=
device
)
cur_sumsq
=
torch
.
zeros
(
num_feats
,
dtype
=
dtype
,
device
=
device
)
cur_sum
=
torch
.
zeros
(
num_channels
,
num_feats
,
dtype
=
dtype
,
device
=
device
)
cur_sumsq
=
torch
.
zeros
(
num_channels
,
num_feats
,
dtype
=
dtype
,
device
=
device
)
cmn_waveform
=
torch
.
zeros
(
num_frames
,
num_feats
,
dtype
=
dtype
,
device
=
device
)
num_channels
,
num_frames
,
num_feats
,
dtype
=
dtype
,
device
=
device
)
for
t
in
range
(
num_frames
):
window_start
=
0
window_end
=
0
...
...
@@ -1742,33 +1746,37 @@ def sliding_window_cmn(
if
window_start
<
0
:
window_start
=
0
if
last_window_start
==
-
1
:
input_part
=
waveform
[
window_start
:
window_end
-
window_start
]
cur_sum
+=
torch
.
sum
(
input_part
,
0
)
input_part
=
waveform
[
:,
window_start
:
window_end
-
window_start
,
:
]
cur_sum
+=
torch
.
sum
(
input_part
,
1
)
if
norm_vars
:
cur_sumsq
+=
torch
.
cumsum
(
input_part
**
2
,
0
)[
-
1
]
cur_sumsq
+=
torch
.
cumsum
(
input_part
**
2
,
1
)[
:,
-
1
,
:
]
else
:
if
window_start
>
last_window_start
:
frame_to_remove
=
waveform
[
last_window_start
]
frame_to_remove
=
waveform
[
:,
last_window_start
,
:
]
cur_sum
-=
frame_to_remove
if
norm_vars
:
cur_sumsq
-=
(
frame_to_remove
**
2
)
if
window_end
>
last_window_end
:
frame_to_add
=
waveform
[
last_window_end
]
frame_to_add
=
waveform
[
:,
last_window_end
,
:
]
cur_sum
+=
frame_to_add
if
norm_vars
:
cur_sumsq
+=
(
frame_to_add
**
2
)
window_frames
=
window_end
-
window_start
last_window_start
=
window_start
last_window_end
=
window_end
cmn_waveform
[
t
]
=
waveform
[
t
]
-
cur_sum
/
window_frames
cmn_waveform
[
:,
t
,
:
]
=
waveform
[
:,
t
,
:
]
-
cur_sum
/
window_frames
if
norm_vars
:
if
window_frames
==
1
:
cmn_waveform
[
t
]
=
torch
.
zeros
(
num_feats
,
dtype
=
dtype
,
device
=
device
)
cmn_waveform
[
:,
t
,
:
]
=
torch
.
zeros
(
num_channels
,
num_feats
,
dtype
=
dtype
,
device
=
device
)
else
:
variance
=
cur_sumsq
variance
=
variance
/
window_frames
variance
-=
((
cur_sum
**
2
)
/
(
window_frames
**
2
))
variance
=
torch
.
pow
(
variance
,
-
0.5
)
cmn_waveform
[
t
]
*=
variance
cmn_waveform
[:,
t
,
:]
*=
variance
cmn_waveform
=
cmn_waveform
.
view
(
input_shape
[:
-
2
]
+
(
num_frames
,
num_feats
))
if
len
(
input_shape
)
==
2
:
cmn_waveform
=
cmn_waveform
.
squeeze
(
0
)
return
cmn_waveform
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment