Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
d857348f
Unverified
Commit
d857348f
authored
Oct 07, 2021
by
nateanl
Committed by
GitHub
Oct 07, 2021
Browse files
[Cherry-picked 0.10] Move LibriMix dataset to datasets directory (#1833)
parent
f9663a7b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
59 additions
and
32 deletions
+59
-32
examples/source_separation/eval.py
examples/source_separation/eval.py
+11
-7
examples/source_separation/lightning_train.py
examples/source_separation/lightning_train.py
+18
-14
examples/source_separation/utils/dataset/__init__.py
examples/source_separation/utils/dataset/__init__.py
+2
-2
examples/source_separation/utils/dataset/utils.py
examples/source_separation/utils/dataset/utils.py
+5
-4
torchaudio/datasets/__init__.py
torchaudio/datasets/__init__.py
+2
-0
torchaudio/datasets/librimix.py
torchaudio/datasets/librimix.py
+21
-5
No files found.
examples/source_separation/eval.py
View file @
d857348f
from
argparse
import
ArgumentParser
from
argparse
import
ArgumentParser
import
p
ath
lib
from
pathlib
import
P
ath
from
lightning_train
import
_get_model
,
_get_dataloader
,
sisdri_metric
from
lightning_train
import
_get_model
,
_get_dataloader
,
sisdri_metric
import
mir_eval
import
mir_eval
import
torch
import
torch
def
eval
(
model
,
data_loader
,
device
):
def
_
eval
(
model
,
data_loader
,
device
):
results
=
torch
.
zeros
(
4
)
results
=
torch
.
zeros
(
4
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
for
i
,
batch
in
enumerate
(
data_loader
):
for
_
,
batch
in
enumerate
(
data_loader
):
mix
,
src
,
mask
=
batch
mix
,
src
,
mask
=
batch
mix
,
src
,
mask
=
mix
.
to
(
device
),
src
.
to
(
device
),
mask
.
to
(
device
)
mix
,
src
,
mask
=
mix
.
to
(
device
),
src
.
to
(
device
),
mask
.
to
(
device
)
est
=
model
(
mix
)
est
=
model
(
mix
)
...
@@ -35,7 +35,11 @@ def eval(model, data_loader, device):
...
@@ -35,7 +35,11 @@ def eval(model, data_loader, device):
def
cli_main
():
def
cli_main
():
parser
=
ArgumentParser
()
parser
=
ArgumentParser
()
parser
.
add_argument
(
"--dataset"
,
default
=
"librimix"
,
type
=
str
,
choices
=
[
"wsj0-mix"
,
"librimix"
])
parser
.
add_argument
(
"--dataset"
,
default
=
"librimix"
,
type
=
str
,
choices
=
[
"wsj0-mix"
,
"librimix"
])
parser
.
add_argument
(
"--data-dir"
,
default
=
pathlib
.
Path
(
"./Libri2Mix/wav8k/min"
),
type
=
pathlib
.
Path
)
parser
.
add_argument
(
"--root-dir"
,
type
=
Path
,
help
=
"The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--librimix-tr-split"
,
"--librimix-tr-split"
,
default
=
"train-360"
,
default
=
"train-360"
,
...
@@ -60,8 +64,8 @@ def cli_main():
...
@@ -60,8 +64,8 @@ def cli_main():
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--exp-dir"
,
"--exp-dir"
,
default
=
pathlib
.
Path
(
"./exp"
),
default
=
Path
(
"./exp"
),
type
=
pathlib
.
Path
,
type
=
Path
,
help
=
"The directory to save checkpoints and logs."
help
=
"The directory to save checkpoints and logs."
)
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -95,7 +99,7 @@ def cli_main():
...
@@ -95,7 +99,7 @@ def cli_main():
args
.
librimix_tr_split
,
args
.
librimix_tr_split
,
)
)
eval
(
model
,
eval_loader
,
device
)
_
eval
(
model
,
eval_loader
,
device
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/source_separation/lightning_train.py
View file @
d857348f
#!/usr/bin/env python3
#!/usr/bin/env python3
# pyre-strict
# pyre-strict
from
pathlib
import
Path
import
pathlib
from
argparse
import
ArgumentParser
from
argparse
import
ArgumentParser
from
typing
import
(
from
typing
import
(
Any
,
Any
,
...
@@ -13,6 +12,7 @@ from typing import (
...
@@ -13,6 +12,7 @@ from typing import (
Optional
,
Optional
,
Tuple
,
Tuple
,
TypedDict
,
TypedDict
,
Union
,
)
)
import
torch
import
torch
...
@@ -279,7 +279,7 @@ def _get_model(
...
@@ -279,7 +279,7 @@ def _get_model(
def
_get_dataloader
(
def
_get_dataloader
(
dataset_type
:
str
,
dataset_type
:
str
,
datase
t_dir
:
pathlib
.
Path
,
roo
t_dir
:
Union
[
str
,
Path
]
,
num_speakers
:
int
=
2
,
num_speakers
:
int
=
2
,
sample_rate
:
int
=
8000
,
sample_rate
:
int
=
8000
,
batch_size
:
int
=
6
,
batch_size
:
int
=
6
,
...
@@ -291,11 +291,11 @@ def _get_dataloader(
...
@@ -291,11 +291,11 @@ def _get_dataloader(
Args:
Args:
dataset_type (str): the dataset to use.
dataset_type (str): the dataset to use.
datase
t_dir (
pathlib.
Path): the root directory of the dataset.
roo
t_dir (
str or
Path): the root directory of the dataset.
num_speakers (int): the number of speakers in the mixture. (Default: 2)
num_speakers (int
, optional
): the number of speakers in the mixture. (Default: 2)
sample_rate (int): the sample rate of the audio. (Default: 8000)
sample_rate (int
, optional
): the sample rate of the audio. (Default: 8000)
batch_size (int): the batch size of the dataset. (Default: 6)
batch_size (int
, optional
): the batch size of the dataset. (Default: 6)
num_workers (int): the number of workers for each dataloader. (Default: 4)
num_workers (int
, optional
): the number of workers for each dataloader. (Default: 4)
librimix_task (str or None, optional): the task in LibriMix dataset.
librimix_task (str or None, optional): the task in LibriMix dataset.
librimix_tr_split (str or None, optional): the training split in LibriMix dataset.
librimix_tr_split (str or None, optional): the training split in LibriMix dataset.
...
@@ -303,7 +303,7 @@ def _get_dataloader(
...
@@ -303,7 +303,7 @@ def _get_dataloader(
tuple: (train_loader, valid_loader, eval_loader)
tuple: (train_loader, valid_loader, eval_loader)
"""
"""
train_dataset
,
valid_dataset
,
eval_dataset
=
dataset_utils
.
get_dataset
(
train_dataset
,
valid_dataset
,
eval_dataset
=
dataset_utils
.
get_dataset
(
dataset_type
,
datase
t_dir
,
num_speakers
,
sample_rate
,
librimix_task
,
librimix_tr_split
dataset_type
,
roo
t_dir
,
num_speakers
,
sample_rate
,
librimix_task
,
librimix_tr_split
)
)
train_collate_fn
=
dataset_utils
.
get_collate_fn
(
train_collate_fn
=
dataset_utils
.
get_collate_fn
(
dataset_type
,
mode
=
'train'
,
sample_rate
=
sample_rate
,
duration
=
3
dataset_type
,
mode
=
'train'
,
sample_rate
=
sample_rate
,
duration
=
3
...
@@ -337,9 +337,13 @@ def _get_dataloader(
...
@@ -337,9 +337,13 @@ def _get_dataloader(
def
cli_main
():
def
cli_main
():
parser
=
ArgumentParser
()
parser
=
ArgumentParser
()
parser
.
add_argument
(
"--batch-size"
,
default
=
3
,
type
=
int
)
parser
.
add_argument
(
"--batch-size"
,
default
=
6
,
type
=
int
)
parser
.
add_argument
(
"--dataset"
,
default
=
"librimix"
,
type
=
str
,
choices
=
[
"wsj0-mix"
,
"librimix"
])
parser
.
add_argument
(
"--dataset"
,
default
=
"librimix"
,
type
=
str
,
choices
=
[
"wsj0-mix"
,
"librimix"
])
parser
.
add_argument
(
"--data-dir"
,
default
=
pathlib
.
Path
(
"./Libri2Mix/wav8k/min"
),
type
=
pathlib
.
Path
)
parser
.
add_argument
(
"--root-dir"
,
type
=
Path
,
help
=
"The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--librimix-tr-split"
,
"--librimix-tr-split"
,
default
=
"train-360"
,
default
=
"train-360"
,
...
@@ -364,8 +368,8 @@ def cli_main():
...
@@ -364,8 +368,8 @@ def cli_main():
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--exp-dir"
,
"--exp-dir"
,
default
=
pathlib
.
Path
(
"./exp"
),
default
=
Path
(
"./exp"
),
type
=
pathlib
.
Path
,
type
=
Path
,
help
=
"The directory to save checkpoints and logs."
help
=
"The directory to save checkpoints and logs."
)
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -404,7 +408,7 @@ def cli_main():
...
@@ -404,7 +408,7 @@ def cli_main():
)
)
train_loader
,
valid_loader
,
eval_loader
=
_get_dataloader
(
train_loader
,
valid_loader
,
eval_loader
=
_get_dataloader
(
args
.
dataset
,
args
.
dataset
,
args
.
data
_dir
,
args
.
root
_dir
,
args
.
num_speakers
,
args
.
num_speakers
,
args
.
sample_rate
,
args
.
sample_rate
,
args
.
batch_size
,
args
.
batch_size
,
...
...
examples/source_separation/utils/dataset/__init__.py
View file @
d857348f
from
.
import
utils
,
wsj0mix
,
librimix
from
.
import
utils
,
wsj0mix
__all__
=
[
'utils'
,
'wsj0mix'
,
'librimix'
]
__all__
=
[
'utils'
,
'wsj0mix'
]
examples/source_separation/utils/dataset/utils.py
View file @
d857348f
...
@@ -2,9 +2,10 @@ from typing import List
...
@@ -2,9 +2,10 @@ from typing import List
from
functools
import
partial
from
functools
import
partial
from
collections
import
namedtuple
from
collections
import
namedtuple
from
torchaudio.datasets
import
LibriMix
import
torch
import
torch
from
.
import
wsj0mix
,
librimix
from
.
import
wsj0mix
Batch
=
namedtuple
(
"Batch"
,
[
"mix"
,
"src"
,
"mask"
])
Batch
=
namedtuple
(
"Batch"
,
[
"mix"
,
"src"
,
"mask"
])
...
@@ -15,9 +16,9 @@ def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, li
...
@@ -15,9 +16,9 @@ def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, li
validation
=
wsj0mix
.
WSJ0Mix
(
root_dir
/
"cv"
,
num_speakers
,
sample_rate
)
validation
=
wsj0mix
.
WSJ0Mix
(
root_dir
/
"cv"
,
num_speakers
,
sample_rate
)
evaluation
=
wsj0mix
.
WSJ0Mix
(
root_dir
/
"tt"
,
num_speakers
,
sample_rate
)
evaluation
=
wsj0mix
.
WSJ0Mix
(
root_dir
/
"tt"
,
num_speakers
,
sample_rate
)
elif
dataset_type
==
"librimix"
:
elif
dataset_type
==
"librimix"
:
train
=
librimix
.
LibriMix
(
root_dir
/
librimix_tr_split
,
num_speakers
,
sample_rate
,
task
)
train
=
LibriMix
(
root_dir
,
librimix_tr_split
,
num_speakers
,
sample_rate
,
task
)
validation
=
librimix
.
LibriMix
(
root_dir
/
"dev"
,
num_speakers
,
sample_rate
,
task
)
validation
=
LibriMix
(
root_dir
,
"dev"
,
num_speakers
,
sample_rate
,
task
)
evaluation
=
librimix
.
LibriMix
(
root_dir
/
"test"
,
num_speakers
,
sample_rate
,
task
)
evaluation
=
LibriMix
(
root_dir
,
"test"
,
num_speakers
,
sample_rate
,
task
)
else
:
else
:
raise
ValueError
(
f
"Unexpected dataset:
{
dataset_type
}
"
)
raise
ValueError
(
f
"Unexpected dataset:
{
dataset_type
}
"
)
return
train
,
validation
,
evaluation
return
train
,
validation
,
evaluation
...
...
torchaudio/datasets/__init__.py
View file @
d857348f
...
@@ -8,6 +8,7 @@ from .yesno import YESNO
...
@@ -8,6 +8,7 @@ from .yesno import YESNO
from
.ljspeech
import
LJSPEECH
from
.ljspeech
import
LJSPEECH
from
.cmuarctic
import
CMUARCTIC
from
.cmuarctic
import
CMUARCTIC
from
.cmudict
import
CMUDict
from
.cmudict
import
CMUDict
from
.librimix
import
LibriMix
from
.libritts
import
LIBRITTS
from
.libritts
import
LIBRITTS
from
.tedlium
import
TEDLIUM
from
.tedlium
import
TEDLIUM
...
@@ -23,6 +24,7 @@ __all__ = [
...
@@ -23,6 +24,7 @@ __all__ = [
"GTZAN"
,
"GTZAN"
,
"CMUARCTIC"
,
"CMUARCTIC"
,
"CMUDict"
,
"CMUDict"
,
"LibriMix"
,
"LIBRITTS"
,
"LIBRITTS"
,
"TEDLIUM"
,
"TEDLIUM"
,
]
]
examples/source_separation/utils
/dataset/librimix.py
→
torchaudio
/dataset
s
/librimix.py
View file @
d857348f
...
@@ -13,27 +13,43 @@ class LibriMix(Dataset):
...
@@ -13,27 +13,43 @@ class LibriMix(Dataset):
r
"""Create the LibriMix dataset.
r
"""Create the LibriMix dataset.
Args:
Args:
root (str or Path): the path to the directory where the dataset is stored.
root (str or Path): The path to the directory that contains the ``Libri2Mix`` or
``Libri3Mix`` directory.
subset (str, optional): The subset to use. Options: [``train-360``, ``train-100``,
``dev``, and ``test``] (Default: ``train-360``).
num_speakers (int, optional): The number of speakers, which determines the directories
num_speakers (int, optional): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios. (Default: 2)
N source audios. (Default: 2)
sample_rate (int, optional): sample rate of audio files. If any of the audio has a
sample_rate (int, optional): sample rate of audio files. The ``sample_rate`` determines
different sample rate, raises ``ValueError``. (Default: 8000)
which subdirectory the audio is fetched from. If any of the audio has a different sample
rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
task (str, optional): the task of LibriMix.
task (str, optional): the task of LibriMix.
Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``]
Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``]
(Default: ``sep_clean``)
(Default: ``sep_clean``)
Note:
The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
root
:
Union
[
str
,
Path
],
root
:
Union
[
str
,
Path
],
subset
:
str
=
"train-360"
,
num_speakers
:
int
=
2
,
num_speakers
:
int
=
2
,
sample_rate
:
int
=
8000
,
sample_rate
:
int
=
8000
,
task
:
str
=
"sep_clean"
,
task
:
str
=
"sep_clean"
,
):
):
self
.
root
=
Path
(
root
)
self
.
root
=
Path
(
root
)
/
f
"Libri
{
num_speakers
}
Mix"
if
sample_rate
==
8000
:
self
.
root
=
self
.
root
/
"wav8k/min"
/
subset
elif
sample_rate
==
16000
:
self
.
root
=
self
.
root
/
"wav16k/min"
/
subset
else
:
raise
ValueError
(
f
"Unsupported sample rate. Found
{
sample_rate
}
."
)
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
task
=
task
self
.
task
=
task
self
.
mix_dir
=
(
self
.
root
/
"mix_{
}"
.
format
(
task
.
split
(
'_'
)[
1
]
)
).
resolve
()
self
.
mix_dir
=
(
self
.
root
/
f
"mix_
{
task
.
split
(
'_'
)[
1
]
}
"
).
resolve
()
self
.
src_dirs
=
[(
self
.
root
/
f
"s
{
i
+
1
}
"
).
resolve
()
for
i
in
range
(
num_speakers
)]
self
.
src_dirs
=
[(
self
.
root
/
f
"s
{
i
+
1
}
"
).
resolve
()
for
i
in
range
(
num_speakers
)]
self
.
files
=
[
p
.
name
for
p
in
self
.
mix_dir
.
glob
(
"*wav"
)]
self
.
files
=
[
p
.
name
for
p
in
self
.
mix_dir
.
glob
(
"*wav"
)]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment