Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
dc0990c7
Commit
dc0990c7
authored
Oct 07, 2021
by
nateanl
Committed by
Zhaoheng Ni
Oct 07, 2021
Browse files
[Cherry-picked 0.10] Move LibriMix dataset to datasets directory (#1833)
parent
e6fccfda
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
59 additions
and
32 deletions
+59
-32
examples/source_separation/eval.py
examples/source_separation/eval.py
+11
-7
examples/source_separation/lightning_train.py
examples/source_separation/lightning_train.py
+18
-14
examples/source_separation/utils/dataset/__init__.py
examples/source_separation/utils/dataset/__init__.py
+2
-2
examples/source_separation/utils/dataset/utils.py
examples/source_separation/utils/dataset/utils.py
+5
-4
torchaudio/datasets/__init__.py
torchaudio/datasets/__init__.py
+2
-0
torchaudio/datasets/librimix.py
torchaudio/datasets/librimix.py
+21
-5
No files found.
examples/source_separation/eval.py
View file @
dc0990c7
from
argparse
import
ArgumentParser
import
p
ath
lib
from
pathlib
import
P
ath
from
lightning_train
import
_get_model
,
_get_dataloader
,
sisdri_metric
import
mir_eval
import
torch
def
eval
(
model
,
data_loader
,
device
):
def
_
eval
(
model
,
data_loader
,
device
):
results
=
torch
.
zeros
(
4
)
with
torch
.
no_grad
():
for
i
,
batch
in
enumerate
(
data_loader
):
for
_
,
batch
in
enumerate
(
data_loader
):
mix
,
src
,
mask
=
batch
mix
,
src
,
mask
=
mix
.
to
(
device
),
src
.
to
(
device
),
mask
.
to
(
device
)
est
=
model
(
mix
)
...
...
@@ -35,7 +35,11 @@ def eval(model, data_loader, device):
def
cli_main
():
parser
=
ArgumentParser
()
parser
.
add_argument
(
"--dataset"
,
default
=
"librimix"
,
type
=
str
,
choices
=
[
"wsj0-mix"
,
"librimix"
])
parser
.
add_argument
(
"--data-dir"
,
default
=
pathlib
.
Path
(
"./Libri2Mix/wav8k/min"
),
type
=
pathlib
.
Path
)
parser
.
add_argument
(
"--root-dir"
,
type
=
Path
,
help
=
"The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored."
,
)
parser
.
add_argument
(
"--librimix-tr-split"
,
default
=
"train-360"
,
...
...
@@ -60,8 +64,8 @@ def cli_main():
)
parser
.
add_argument
(
"--exp-dir"
,
default
=
pathlib
.
Path
(
"./exp"
),
type
=
pathlib
.
Path
,
default
=
Path
(
"./exp"
),
type
=
Path
,
help
=
"The directory to save checkpoints and logs."
)
parser
.
add_argument
(
...
...
@@ -95,7 +99,7 @@ def cli_main():
args
.
librimix_tr_split
,
)
eval
(
model
,
eval_loader
,
device
)
_
eval
(
model
,
eval_loader
,
device
)
if
__name__
==
"__main__"
:
...
...
examples/source_separation/lightning_train.py
View file @
dc0990c7
#!/usr/bin/env python3
# pyre-strict
import
pathlib
from
pathlib
import
Path
from
argparse
import
ArgumentParser
from
typing
import
(
Any
,
...
...
@@ -13,6 +12,7 @@ from typing import (
Optional
,
Tuple
,
TypedDict
,
Union
,
)
import
torch
...
...
@@ -279,7 +279,7 @@ def _get_model(
def
_get_dataloader
(
dataset_type
:
str
,
datase
t_dir
:
pathlib
.
Path
,
roo
t_dir
:
Union
[
str
,
Path
]
,
num_speakers
:
int
=
2
,
sample_rate
:
int
=
8000
,
batch_size
:
int
=
6
,
...
...
@@ -291,11 +291,11 @@ def _get_dataloader(
Args:
dataset_type (str): the dataset to use.
datase
t_dir (
pathlib.
Path): the root directory of the dataset.
num_speakers (int): the number of speakers in the mixture. (Default: 2)
sample_rate (int): the sample rate of the audio. (Default: 8000)
batch_size (int): the batch size of the dataset. (Default: 6)
num_workers (int): the number of workers for each dataloader. (Default: 4)
roo
t_dir (
str or
Path): the root directory of the dataset.
num_speakers (int
, optional
): the number of speakers in the mixture. (Default: 2)
sample_rate (int
, optional
): the sample rate of the audio. (Default: 8000)
batch_size (int
, optional
): the batch size of the dataset. (Default: 6)
num_workers (int
, optional
): the number of workers for each dataloader. (Default: 4)
librimix_task (str or None, optional): the task in LibriMix dataset.
librimix_tr_split (str or None, optional): the training split in LibriMix dataset.
...
...
@@ -303,7 +303,7 @@ def _get_dataloader(
tuple: (train_loader, valid_loader, eval_loader)
"""
train_dataset
,
valid_dataset
,
eval_dataset
=
dataset_utils
.
get_dataset
(
dataset_type
,
datase
t_dir
,
num_speakers
,
sample_rate
,
librimix_task
,
librimix_tr_split
dataset_type
,
roo
t_dir
,
num_speakers
,
sample_rate
,
librimix_task
,
librimix_tr_split
)
train_collate_fn
=
dataset_utils
.
get_collate_fn
(
dataset_type
,
mode
=
'train'
,
sample_rate
=
sample_rate
,
duration
=
3
...
...
@@ -337,9 +337,13 @@ def _get_dataloader(
def
cli_main
():
parser
=
ArgumentParser
()
parser
.
add_argument
(
"--batch-size"
,
default
=
3
,
type
=
int
)
parser
.
add_argument
(
"--batch-size"
,
default
=
6
,
type
=
int
)
parser
.
add_argument
(
"--dataset"
,
default
=
"librimix"
,
type
=
str
,
choices
=
[
"wsj0-mix"
,
"librimix"
])
parser
.
add_argument
(
"--data-dir"
,
default
=
pathlib
.
Path
(
"./Libri2Mix/wav8k/min"
),
type
=
pathlib
.
Path
)
parser
.
add_argument
(
"--root-dir"
,
type
=
Path
,
help
=
"The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored."
,
)
parser
.
add_argument
(
"--librimix-tr-split"
,
default
=
"train-360"
,
...
...
@@ -364,8 +368,8 @@ def cli_main():
)
parser
.
add_argument
(
"--exp-dir"
,
default
=
pathlib
.
Path
(
"./exp"
),
type
=
pathlib
.
Path
,
default
=
Path
(
"./exp"
),
type
=
Path
,
help
=
"The directory to save checkpoints and logs."
)
parser
.
add_argument
(
...
...
@@ -404,7 +408,7 @@ def cli_main():
)
train_loader
,
valid_loader
,
eval_loader
=
_get_dataloader
(
args
.
dataset
,
args
.
data
_dir
,
args
.
root
_dir
,
args
.
num_speakers
,
args
.
sample_rate
,
args
.
batch_size
,
...
...
examples/source_separation/utils/dataset/__init__.py
View file @
dc0990c7
from
.
import
utils
,
wsj0mix
,
librimix
from
.
import
utils
,
wsj0mix
__all__
=
[
'utils'
,
'wsj0mix'
,
'librimix'
]
__all__
=
[
'utils'
,
'wsj0mix'
]
examples/source_separation/utils/dataset/utils.py
View file @
dc0990c7
...
...
@@ -2,9 +2,10 @@ from typing import List
from
functools
import
partial
from
collections
import
namedtuple
from
torchaudio.datasets
import
LibriMix
import
torch
from
.
import
wsj0mix
,
librimix
from
.
import
wsj0mix
Batch
=
namedtuple
(
"Batch"
,
[
"mix"
,
"src"
,
"mask"
])
...
...
@@ -15,9 +16,9 @@ def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, li
validation
=
wsj0mix
.
WSJ0Mix
(
root_dir
/
"cv"
,
num_speakers
,
sample_rate
)
evaluation
=
wsj0mix
.
WSJ0Mix
(
root_dir
/
"tt"
,
num_speakers
,
sample_rate
)
elif
dataset_type
==
"librimix"
:
train
=
librimix
.
LibriMix
(
root_dir
/
librimix_tr_split
,
num_speakers
,
sample_rate
,
task
)
validation
=
librimix
.
LibriMix
(
root_dir
/
"dev"
,
num_speakers
,
sample_rate
,
task
)
evaluation
=
librimix
.
LibriMix
(
root_dir
/
"test"
,
num_speakers
,
sample_rate
,
task
)
train
=
LibriMix
(
root_dir
,
librimix_tr_split
,
num_speakers
,
sample_rate
,
task
)
validation
=
LibriMix
(
root_dir
,
"dev"
,
num_speakers
,
sample_rate
,
task
)
evaluation
=
LibriMix
(
root_dir
,
"test"
,
num_speakers
,
sample_rate
,
task
)
else
:
raise
ValueError
(
f
"Unexpected dataset:
{
dataset_type
}
"
)
return
train
,
validation
,
evaluation
...
...
torchaudio/datasets/__init__.py
View file @
dc0990c7
...
...
@@ -8,6 +8,7 @@ from .yesno import YESNO
from
.ljspeech
import
LJSPEECH
from
.cmuarctic
import
CMUARCTIC
from
.cmudict
import
CMUDict
from
.librimix
import
LibriMix
from
.libritts
import
LIBRITTS
from
.tedlium
import
TEDLIUM
...
...
@@ -23,6 +24,7 @@ __all__ = [
"GTZAN"
,
"CMUARCTIC"
,
"CMUDict"
,
"LibriMix"
,
"LIBRITTS"
,
"diskcache_iterator"
,
"bg_iterator"
,
...
...
examples/source_separation/utils
/dataset/librimix.py
→
torchaudio
/dataset
s
/librimix.py
View file @
dc0990c7
...
...
@@ -13,27 +13,43 @@ class LibriMix(Dataset):
r
"""Create the LibriMix dataset.
Args:
root (str or Path): the path to the directory where the dataset is stored.
root (str or Path): The path to the directory where the directory ``Libri2Mix`` or
``Libri3Mix`` is stored.
subset (str, optional): The subset to use. Options: [``train-360`, ``train-100``,
``dev``, and ``test``] (Default: ``train-360``).
num_speakers (int, optional): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios. (Default: 2)
sample_rate (int, optional): sample rate of audio files. If any of the audio has a
different sample rate, raises ``ValueError``. (Default: 8000)
sample_rate (int, optional): sample rate of audio files. The ``sample_rate`` determines
which subdirectory the audio are fetched. If any of the audio has a different sample
rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
task (str, optional): the task of LibriMix.
Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``]
(Default: ``sep_clean``)
Note:
The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
"""
def
__init__
(
self
,
root
:
Union
[
str
,
Path
],
subset
:
str
=
"train-360"
,
num_speakers
:
int
=
2
,
sample_rate
:
int
=
8000
,
task
:
str
=
"sep_clean"
,
):
self
.
root
=
Path
(
root
)
self
.
root
=
Path
(
root
)
/
f
"Libri
{
num_speakers
}
Mix"
if
sample_rate
==
8000
:
self
.
root
=
self
.
root
/
"wav8k/min"
/
subset
elif
sample_rate
==
16000
:
self
.
root
=
self
.
root
/
"wav16k/min"
/
subset
else
:
raise
ValueError
(
f
"Unsupported sample rate. Found
{
sample_rate
}
."
)
self
.
sample_rate
=
sample_rate
self
.
task
=
task
self
.
mix_dir
=
(
self
.
root
/
"mix_{
}"
.
format
(
task
.
split
(
'_'
)[
1
]
)
).
resolve
()
self
.
mix_dir
=
(
self
.
root
/
f
"mix_
{
task
.
split
(
'_'
)[
1
]
}
"
).
resolve
()
self
.
src_dirs
=
[(
self
.
root
/
f
"s
{
i
+
1
}
"
).
resolve
()
for
i
in
range
(
num_speakers
)]
self
.
files
=
[
p
.
name
for
p
in
self
.
mix_dir
.
glob
(
"*wav"
)]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment