Unverified Commit 870811c7 authored by jimchen90's avatar jimchen90 Committed by GitHub
Browse files

Add libritts dataset option (#818)


Co-authored-by: default avatarJi Chen <jimchen90@devfair0160.h2.fair>
parent 1ecbc249
......@@ -4,7 +4,7 @@ import random
import torch
import torchaudio
from torch.utils.data.dataset import random_split
from torchaudio.datasets import LJSPEECH
from torchaudio.datasets import LJSPEECH, LIBRITTS
from torchaudio.transforms import MuLawEncoding
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
......@@ -48,12 +48,20 @@ class Processed(torch.utils.data.Dataset):
return item[0].squeeze(0), specgram
def split_process_ljspeech(args, transforms):
data = LJSPEECH(root=args.file_path, download=False)
def split_process_dataset(args, transforms):
if args.dataset == 'ljspeech':
data = LJSPEECH(root=args.file_path, download=False)
val_length = int(len(data) * args.val_ratio)
lengths = [len(data) - val_length, val_length]
train_dataset, val_dataset = random_split(data, lengths)
val_length = int(len(data) * args.val_ratio)
lengths = [len(data) - val_length, val_length]
train_dataset, val_dataset = random_split(data, lengths)
elif args.dataset == 'libritts':
train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False)
val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False)
else:
raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}")
train_dataset = Processed(train_dataset, transforms)
val_dataset = Processed(val_dataset, transforms)
......
......@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.models.wavernn import WaveRNN
from datasets import collate_factory, split_process_ljspeech
from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import LinearToMel, NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint
......@@ -55,6 +55,13 @@ def parse_args():
metavar="N",
help="print frequency in epochs",
)
parser.add_argument(
"--dataset",
default="ljspeech",
choices=["ljspeech", "libritts"],
type=str,
help="select dataset to train with",
)
parser.add_argument(
"--batch-size", default=256, type=int, metavar="N", help="mini-batch size"
)
......@@ -269,7 +276,7 @@ def main(args):
NormalizeDB(min_level_db=args.min_level_db),
)
train_dataset, val_dataset = split_process_ljspeech(args, transforms)
train_dataset, val_dataset = split_process_dataset(args, transforms)
loader_training_params = {
"num_workers": args.workers,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment