Commit a63629b6 authored by Jeff Hwang's avatar Jeff Hwang Committed by Facebook GitHub Bot
Browse files

Refactor LibriSpeech Lightning datamodule to accommodate different dataset implementations (#2437)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2437

Refactors LibriSpeech Lightning datamodule to accommodate different dataset implementations.

Reviewed By: carolineechen, nateanl

Differential Revision: D36731577

fbshipit-source-id: 4ba91044311fa3f99a928aef6ef411316955f6b5
parent 877a88c5
......@@ -103,6 +103,8 @@ class TransformDataset(torch.utils.data.Dataset):
class LibriSpeechDataModule(LightningDataModule):
librispeech_cls = torchaudio.datasets.LIBRISPEECH
def __init__(
self,
*,
......@@ -116,6 +118,7 @@ class LibriSpeechDataModule(LightningDataModule):
train_shuffle=True,
num_workers=10,
):
super().__init__()
self.librispeech_path = librispeech_path
self.train_dataset_lengths = None
self.val_dataset_lengths = None
......@@ -130,9 +133,9 @@ class LibriSpeechDataModule(LightningDataModule):
def train_dataloader(self):
datasets = [
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"),
self.librispeech_cls(self.librispeech_path, url="train-clean-360"),
self.librispeech_cls(self.librispeech_path, url="train-clean-100"),
self.librispeech_cls(self.librispeech_path, url="train-other-500"),
]
if not self.train_dataset_lengths:
......@@ -161,8 +164,8 @@ class LibriSpeechDataModule(LightningDataModule):
def val_dataloader(self):
datasets = [
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"),
self.librispeech_cls(self.librispeech_path, url="dev-clean"),
self.librispeech_cls(self.librispeech_path, url="dev-other"),
]
if not self.val_dataset_lengths:
......@@ -185,7 +188,7 @@ class LibriSpeechDataModule(LightningDataModule):
return dataloader
def test_dataloader(self):
dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
dataset = self.librispeech_cls(self.librispeech_path, url="test-clean")
dataset = TransformDataset(dataset, self.test_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
return dataloader
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment