Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
2d879132
Unverified
Commit
2d879132
authored
Oct 12, 2020
by
moto
Committed by
GitHub
Oct 12, 2020
Browse files
Add wsj0-mix dataset to source separation example (#895)
parent
ba7b7a2f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
266 additions
and
0 deletions
+266
-0
examples/source_separation/utils/__init__.py
examples/source_separation/utils/__init__.py
+1
-0
examples/source_separation/utils/dataset/__init__.py
examples/source_separation/utils/dataset/__init__.py
+1
-0
examples/source_separation/utils/dataset/utils.py
examples/source_separation/utils/dataset/utils.py
+83
-0
examples/source_separation/utils/dataset/wsj0mix.py
examples/source_separation/utils/dataset/wsj0mix.py
+70
-0
test/torchaudio_unittest/example/souce_sepration/wsj0mix_test.py
...rchaudio_unittest/example/souce_sepration/wsj0mix_test.py
+111
-0
No files found.
examples/source_separation/utils/__init__.py
View file @
2d879132
from
.
import
(
from
.
import
(
dataset
,
metrics
,
metrics
,
)
)
examples/source_separation/utils/dataset/__init__.py
0 → 100644
View file @
2d879132
from
.
import
utils
,
wsj0mix
examples/source_separation/utils/dataset/utils.py
0 → 100644
View file @
2d879132
from
typing
import
List
from
functools
import
partial
from
collections
import
namedtuple
import
torch
from
.
import
wsj0mix
Batch
=
namedtuple
(
"Batch"
,
[
"mix"
,
"src"
,
"mask"
])
def
get_dataset
(
dataset_type
,
root_dir
,
num_speakers
,
sample_rate
):
if
dataset_type
==
"wsj0mix"
:
train
=
wsj0mix
.
WSJ0Mix
(
root_dir
/
"tr"
,
num_speakers
,
sample_rate
)
validation
=
wsj0mix
.
WSJ0Mix
(
root_dir
/
"cv"
,
num_speakers
,
sample_rate
)
evaluation
=
wsj0mix
.
WSJ0Mix
(
root_dir
/
"tt"
,
num_speakers
,
sample_rate
)
else
:
raise
ValueError
(
f
"Unexpected dataset:
{
dataset_type
}
"
)
return
train
,
validation
,
evaluation
def
_fix_num_frames
(
sample
:
wsj0mix
.
SampleType
,
target_num_frames
:
int
,
random_start
=
False
):
"""Ensure waveform has exact number of frames by slicing or padding"""
mix
=
sample
[
1
]
# [1, num_frames]
src
=
torch
.
cat
(
sample
[
2
],
0
)
# [num_sources, num_frames]
num_channels
,
num_frames
=
src
.
shape
if
num_frames
>=
target_num_frames
:
if
random_start
and
num_frames
>
target_num_frames
:
start_frame
=
torch
.
randint
(
num_frames
-
target_num_frames
,
[
1
])
mix
=
mix
[:,
start_frame
:]
src
=
src
[:,
start_frame
:]
mix
=
mix
[:,
:
target_num_frames
]
src
=
src
[:,
:
target_num_frames
]
mask
=
torch
.
ones_like
(
mix
)
else
:
num_padding
=
target_num_frames
-
num_frames
pad
=
torch
.
zeros
([
1
,
num_padding
],
dtype
=
mix
.
dtype
,
device
=
mix
.
device
)
mix
=
torch
.
cat
([
mix
,
pad
],
1
)
src
=
torch
.
cat
([
src
,
pad
.
expand
(
num_channels
,
-
1
)],
1
)
mask
=
torch
.
ones_like
(
mix
)
mask
[...,
num_frames
:]
=
0
return
mix
,
src
,
mask
def
collate_fn_wsj0mix_train
(
samples
:
List
[
wsj0mix
.
SampleType
],
sample_rate
,
duration
):
target_num_frames
=
int
(
duration
*
sample_rate
)
mixes
,
srcs
,
masks
=
[],
[],
[]
for
sample
in
samples
:
mix
,
src
,
mask
=
_fix_num_frames
(
sample
,
target_num_frames
,
random_start
=
True
)
mixes
.
append
(
mix
)
srcs
.
append
(
src
)
masks
.
append
(
mask
)
return
Batch
(
torch
.
stack
(
mixes
,
0
),
torch
.
stack
(
srcs
,
0
),
torch
.
stack
(
masks
,
0
))
def
collate_fn_wsj0mix_test
(
samples
:
List
[
wsj0mix
.
SampleType
]):
max_num_frames
=
max
(
s
[
1
].
shape
[
-
1
]
for
s
in
samples
)
mixes
,
srcs
,
masks
=
[],
[],
[]
for
sample
in
samples
:
mix
,
src
,
mask
=
_fix_num_frames
(
sample
,
max_num_frames
,
random_start
=
False
)
mixes
.
append
(
mix
)
srcs
.
append
(
src
)
masks
.
append
(
mask
)
return
Batch
(
torch
.
stack
(
mixes
,
0
),
torch
.
stack
(
srcs
,
0
),
torch
.
stack
(
masks
,
0
))
def
get_collate_fn
(
dataset_type
,
mode
,
sample_rate
=
None
,
duration
=
4
):
assert
mode
in
[
"train"
,
"test"
]
if
dataset_type
==
"wsj0mix"
:
if
mode
==
'train'
:
if
sample_rate
is
None
:
raise
ValueError
(
"sample_rate is not given."
)
return
partial
(
collate_fn_wsj0mix_train
,
sample_rate
=
sample_rate
,
duration
=
duration
)
return
collate_fn_wsj0mix_test
raise
ValueError
(
f
"Unexpected dataset:
{
dataset_type
}
"
)
examples/source_separation/utils/dataset/wsj0mix.py
0 → 100644
View file @
2d879132
from
pathlib
import
Path
from
typing
import
Union
,
Tuple
,
List
import
torch
from
torch.utils.data
import
Dataset
import
torchaudio
SampleType
=
Tuple
[
int
,
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
class
WSJ0Mix
(
Dataset
):
"""Create a Dataset for wsj0-mix.
Args:
root (str or Path): Path to the directory where the dataset is found.
num_speakers (int): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios.
sample_rate (int): Expected sample rate of audio files. If any of the audio has a
different sample rate, raises ``ValueError``.
audio_ext (str): The extension of audio files to find. (default: ".wav")
"""
def
__init__
(
self
,
root
:
Union
[
str
,
Path
],
num_speakers
:
int
,
sample_rate
:
int
,
audio_ext
:
str
=
".wav"
,
):
self
.
root
=
Path
(
root
)
self
.
sample_rate
=
sample_rate
self
.
mix_dir
=
(
self
.
root
/
"mix"
).
resolve
()
self
.
src_dirs
=
[(
self
.
root
/
f
"s
{
i
+
1
}
"
).
resolve
()
for
i
in
range
(
num_speakers
)]
self
.
files
=
[
p
.
name
for
p
in
self
.
mix_dir
.
glob
(
f
"*
{
audio_ext
}
"
)]
self
.
files
.
sort
()
def
_load_audio
(
self
,
path
)
->
torch
.
Tensor
:
waveform
,
sample_rate
=
torchaudio
.
load
(
path
)
if
sample_rate
!=
self
.
sample_rate
:
raise
ValueError
(
f
"The dataset contains audio file of sample rate
{
sample_rate
}
. "
"Where the requested sample rate is {self.sample_rate}."
)
return
waveform
def
_load_sample
(
self
,
filename
)
->
SampleType
:
mixed
=
self
.
_load_audio
(
str
(
self
.
mix_dir
/
filename
))
srcs
=
[]
for
i
,
dir_
in
enumerate
(
self
.
src_dirs
):
src
=
self
.
_load_audio
(
str
(
dir_
/
filename
))
if
mixed
.
shape
!=
src
.
shape
:
raise
ValueError
(
f
"Different waveform shapes. mixed:
{
mixed
.
shape
}
, src[
{
i
}
]:
{
src
.
shape
}
"
)
srcs
.
append
(
src
)
return
self
.
sample_rate
,
mixed
,
srcs
def
__len__
(
self
)
->
int
:
return
len
(
self
.
files
)
def
__getitem__
(
self
,
key
:
int
)
->
SampleType
:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
"""
return
self
.
_load_sample
(
self
.
files
[
key
])
test/torchaudio_unittest/example/souce_sepration/wsj0mix_test.py
0 → 100644
View file @
2d879132
import
os
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
TorchaudioTestCase
,
get_whitenoise
,
save_wav
,
normalize_wav
,
)
from
source_separation.utils.dataset
import
wsj0mix
_FILENAMES
=
[
"012c0207_1.9952_01cc0202_-1.9952.wav"
,
"01co0302_1.63_014c020q_-1.63.wav"
,
"01do0316_0.24011_205a0104_-0.24011.wav"
,
"01lc020x_1.1301_027o030r_-1.1301.wav"
,
"01mc0202_0.34056_205o0106_-0.34056.wav"
,
"01nc020t_0.53821_018o030w_-0.53821.wav"
,
"01po030f_2.2136_40ko031a_-2.2136.wav"
,
"01ra010o_2.4098_403a010f_-2.4098.wav"
,
"01xo030b_0.22377_016o031a_-0.22377.wav"
,
"02ac020x_0.68566_01ec020b_-0.68566.wav"
,
"20co010m_0.82801_019c0212_-0.82801.wav"
,
"20da010u_1.2483_017c0211_-1.2483.wav"
,
"20oo010d_1.0631_01ic020s_-1.0631.wav"
,
"20sc0107_2.0222_20fo010h_-2.0222.wav"
,
"20tc010f_0.051456_404a0110_-0.051456.wav"
,
"407c0214_1.1712_02ca0113_-1.1712.wav"
,
"40ao030w_2.4697_20vc010a_-2.4697.wav"
,
"40pa0101_1.1087_40ea0107_-1.1087.wav"
,
]
def
_mock_dataset
(
root_dir
,
num_speaker
):
dirnames
=
[
"mix"
]
+
[
f
"s
{
i
+
1
}
"
for
i
in
range
(
num_speaker
)]
for
dirname
in
dirnames
:
os
.
makedirs
(
os
.
path
.
join
(
root_dir
,
dirname
),
exist_ok
=
True
)
seed
=
0
sample_rate
=
8000
expected
=
[]
for
filename
in
_FILENAMES
:
mix
=
None
src
=
[]
for
dirname
in
dirnames
:
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
1
,
dtype
=
"int16"
,
seed
=
seed
)
seed
+=
1
path
=
os
.
path
.
join
(
root_dir
,
dirname
,
filename
)
save_wav
(
path
,
waveform
,
sample_rate
)
waveform
=
normalize_wav
(
waveform
)
if
dirname
==
"mix"
:
mix
=
waveform
else
:
src
.
append
(
waveform
)
expected
.
append
((
sample_rate
,
mix
,
src
))
return
expected
class
TestWSJ0Mix2
(
TempDirMixin
,
TorchaudioTestCase
):
backend
=
"default"
root_dir
=
None
expected
=
None
@
classmethod
def
setUpClass
(
cls
):
cls
.
root_dir
=
cls
.
get_base_temp_dir
()
cls
.
expected
=
_mock_dataset
(
cls
.
root_dir
,
2
)
def
test_wsj0mix
(
self
):
dataset
=
wsj0mix
.
WSJ0Mix
(
self
.
root_dir
,
num_speakers
=
2
,
sample_rate
=
8000
)
n_ite
=
0
for
i
,
sample
in
enumerate
(
dataset
):
(
_
,
sample_mix
,
sample_src
)
=
sample
(
_
,
expected_mix
,
expected_src
)
=
self
.
expected
[
i
]
self
.
assertEqual
(
sample_mix
,
expected_mix
,
atol
=
5e-5
,
rtol
=
1e-8
)
self
.
assertEqual
(
sample_src
[
0
],
expected_src
[
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
self
.
assertEqual
(
sample_src
[
1
],
expected_src
[
1
],
atol
=
5e-5
,
rtol
=
1e-8
)
n_ite
+=
1
assert
n_ite
==
len
(
self
.
expected
)
class
TestWSJ0Mix3
(
TempDirMixin
,
TorchaudioTestCase
):
backend
=
"default"
root_dir
=
None
expected
=
None
@
classmethod
def
setUpClass
(
cls
):
cls
.
root_dir
=
cls
.
get_base_temp_dir
()
cls
.
expected
=
_mock_dataset
(
cls
.
root_dir
,
3
)
def
test_wsj0mix
(
self
):
dataset
=
wsj0mix
.
WSJ0Mix
(
self
.
root_dir
,
num_speakers
=
3
,
sample_rate
=
8000
)
n_ite
=
0
for
i
,
sample
in
enumerate
(
dataset
):
(
_
,
sample_mix
,
sample_src
)
=
sample
(
_
,
expected_mix
,
expected_src
)
=
self
.
expected
[
i
]
self
.
assertEqual
(
sample_mix
,
expected_mix
,
atol
=
5e-5
,
rtol
=
1e-8
)
self
.
assertEqual
(
sample_src
[
0
],
expected_src
[
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
self
.
assertEqual
(
sample_src
[
1
],
expected_src
[
1
],
atol
=
5e-5
,
rtol
=
1e-8
)
self
.
assertEqual
(
sample_src
[
2
],
expected_src
[
2
],
atol
=
5e-5
,
rtol
=
1e-8
)
n_ite
+=
1
assert
n_ite
==
len
(
self
.
expected
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment