Unverified Commit 870811c7 authored by jimchen90's avatar jimchen90 Committed by GitHub
Browse files

Add libritts dataset option (#818)


Co-authored-by: Ji Chen <jimchen90@devfair0160.h2.fair>
parent 1ecbc249
...@@ -4,7 +4,7 @@ import random ...@@ -4,7 +4,7 @@ import random
import torch import torch
import torchaudio import torchaudio
from torch.utils.data.dataset import random_split from torch.utils.data.dataset import random_split
from torchaudio.datasets import LJSPEECH from torchaudio.datasets import LJSPEECH, LIBRITTS
from torchaudio.transforms import MuLawEncoding from torchaudio.transforms import MuLawEncoding
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
...@@ -48,12 +48,20 @@ class Processed(torch.utils.data.Dataset): ...@@ -48,12 +48,20 @@ class Processed(torch.utils.data.Dataset):
return item[0].squeeze(0), specgram return item[0].squeeze(0), specgram
def split_process_ljspeech(args, transforms): def split_process_dataset(args, transforms):
data = LJSPEECH(root=args.file_path, download=False) if args.dataset == 'ljspeech':
data = LJSPEECH(root=args.file_path, download=False)
val_length = int(len(data) * args.val_ratio) val_length = int(len(data) * args.val_ratio)
lengths = [len(data) - val_length, val_length] lengths = [len(data) - val_length, val_length]
train_dataset, val_dataset = random_split(data, lengths) train_dataset, val_dataset = random_split(data, lengths)
elif args.dataset == 'libritts':
train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False)
val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False)
else:
raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}")
train_dataset = Processed(train_dataset, transforms) train_dataset = Processed(train_dataset, transforms)
val_dataset = Processed(val_dataset, transforms) val_dataset = Processed(val_dataset, transforms)
......
...@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader ...@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator from torchaudio.datasets.utils import bg_iterator
from torchaudio.models.wavernn import WaveRNN from torchaudio.models.wavernn import WaveRNN
from datasets import collate_factory, split_process_ljspeech from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss from losses import LongCrossEntropyLoss, MoLLoss
from processing import LinearToMel, NormalizeDB from processing import LinearToMel, NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint from utils import MetricLogger, count_parameters, save_checkpoint
...@@ -55,6 +55,13 @@ def parse_args(): ...@@ -55,6 +55,13 @@ def parse_args():
metavar="N", metavar="N",
help="print frequency in epochs", help="print frequency in epochs",
) )
parser.add_argument(
"--dataset",
default="ljspeech",
choices=["ljspeech", "libritts"],
type=str,
help="select dataset to train with",
)
parser.add_argument( parser.add_argument(
"--batch-size", default=256, type=int, metavar="N", help="mini-batch size" "--batch-size", default=256, type=int, metavar="N", help="mini-batch size"
) )
...@@ -269,7 +276,7 @@ def main(args): ...@@ -269,7 +276,7 @@ def main(args):
NormalizeDB(min_level_db=args.min_level_db), NormalizeDB(min_level_db=args.min_level_db),
) )
train_dataset, val_dataset = split_process_ljspeech(args, transforms) train_dataset, val_dataset = split_process_dataset(args, transforms)
loader_training_params = { loader_training_params = {
"num_workers": args.workers, "num_workers": args.workers,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment