Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
870811c7
Unverified
Commit
870811c7
authored
Jul 30, 2020
by
jimchen90
Committed by
GitHub
Jul 30, 2020
Browse files
Add libritts dataset option (#818)
Co-authored-by:
Ji Chen
<
jimchen90@devfair0160.h2.fair
>
parent
1ecbc249
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
8 deletions
+23
-8
examples/pipeline_wavernn/datasets.py
examples/pipeline_wavernn/datasets.py
+14
-6
examples/pipeline_wavernn/main.py
examples/pipeline_wavernn/main.py
+9
-2
No files found.
examples/pipeline_wavernn/datasets.py
View file @
870811c7
...
@@ -4,7 +4,7 @@ import random
...
@@ -4,7 +4,7 @@ import random
import
torch
import
torch
import
torchaudio
import
torchaudio
from
torch.utils.data.dataset
import
random_split
from
torch.utils.data.dataset
import
random_split
from
torchaudio.datasets
import
LJSPEECH
from
torchaudio.datasets
import
LJSPEECH
,
LIBRITTS
from
torchaudio.transforms
import
MuLawEncoding
from
torchaudio.transforms
import
MuLawEncoding
from
processing
import
bits_to_normalized_waveform
,
normalized_waveform_to_bits
from
processing
import
bits_to_normalized_waveform
,
normalized_waveform_to_bits
...
@@ -48,12 +48,20 @@ class Processed(torch.utils.data.Dataset):
...
@@ -48,12 +48,20 @@ class Processed(torch.utils.data.Dataset):
return
item
[
0
].
squeeze
(
0
),
specgram
return
item
[
0
].
squeeze
(
0
),
specgram
def
split_process_ljspeech
(
args
,
transforms
):
def
split_process_dataset
(
args
,
transforms
):
data
=
LJSPEECH
(
root
=
args
.
file_path
,
download
=
False
)
if
args
.
dataset
==
'ljspeech'
:
data
=
LJSPEECH
(
root
=
args
.
file_path
,
download
=
False
)
val_length
=
int
(
len
(
data
)
*
args
.
val_ratio
)
val_length
=
int
(
len
(
data
)
*
args
.
val_ratio
)
lengths
=
[
len
(
data
)
-
val_length
,
val_length
]
lengths
=
[
len
(
data
)
-
val_length
,
val_length
]
train_dataset
,
val_dataset
=
random_split
(
data
,
lengths
)
train_dataset
,
val_dataset
=
random_split
(
data
,
lengths
)
elif
args
.
dataset
==
'libritts'
:
train_dataset
=
LIBRITTS
(
root
=
args
.
file_path
,
url
=
'train-clean-100'
,
download
=
False
)
val_dataset
=
LIBRITTS
(
root
=
args
.
file_path
,
url
=
'dev-clean'
,
download
=
False
)
else
:
raise
ValueError
(
f
"Expected dataset: `ljspeech` or `libritts`, but found
{
args
.
dataset
}
"
)
train_dataset
=
Processed
(
train_dataset
,
transforms
)
train_dataset
=
Processed
(
train_dataset
,
transforms
)
val_dataset
=
Processed
(
val_dataset
,
transforms
)
val_dataset
=
Processed
(
val_dataset
,
transforms
)
...
...
examples/pipeline_wavernn/main.py
View file @
870811c7
...
@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
...
@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
from
torchaudio.datasets.utils
import
bg_iterator
from
torchaudio.datasets.utils
import
bg_iterator
from
torchaudio.models.wavernn
import
WaveRNN
from
torchaudio.models.wavernn
import
WaveRNN
from
datasets
import
collate_factory
,
split_process_
ljspeech
from
datasets
import
collate_factory
,
split_process_
dataset
from
losses
import
LongCrossEntropyLoss
,
MoLLoss
from
losses
import
LongCrossEntropyLoss
,
MoLLoss
from
processing
import
LinearToMel
,
NormalizeDB
from
processing
import
LinearToMel
,
NormalizeDB
from
utils
import
MetricLogger
,
count_parameters
,
save_checkpoint
from
utils
import
MetricLogger
,
count_parameters
,
save_checkpoint
...
@@ -55,6 +55,13 @@ def parse_args():
...
@@ -55,6 +55,13 @@ def parse_args():
metavar
=
"N"
,
metavar
=
"N"
,
help
=
"print frequency in epochs"
,
help
=
"print frequency in epochs"
,
)
)
parser
.
add_argument
(
"--dataset"
,
default
=
"ljspeech"
,
choices
=
[
"ljspeech"
,
"libritts"
],
type
=
str
,
help
=
"select dataset to train with"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--batch-size"
,
default
=
256
,
type
=
int
,
metavar
=
"N"
,
help
=
"mini-batch size"
"--batch-size"
,
default
=
256
,
type
=
int
,
metavar
=
"N"
,
help
=
"mini-batch size"
)
)
...
@@ -269,7 +276,7 @@ def main(args):
...
@@ -269,7 +276,7 @@ def main(args):
NormalizeDB
(
min_level_db
=
args
.
min_level_db
),
NormalizeDB
(
min_level_db
=
args
.
min_level_db
),
)
)
train_dataset
,
val_dataset
=
split_process_
ljspeech
(
args
,
transforms
)
train_dataset
,
val_dataset
=
split_process_
dataset
(
args
,
transforms
)
loader_training_params
=
{
loader_training_params
=
{
"num_workers"
:
args
.
workers
,
"num_workers"
:
args
.
workers
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment