Commit a63629b6, authored by Jeff Hwang and committed by Facebook GitHub Bot
Browse files

Refactor LibriSpeech Lightning datamodule to accommodate different dataset implementations (#2437)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2437

Refactors LibriSpeech Lightning datamodule to accommodate different dataset implementations.

Reviewed By: carolineechen, nateanl

Differential Revision: D36731577

fbshipit-source-id: 4ba91044311fa3f99a928aef6ef411316955f6b5
parent 877a88c5
...@@ -103,6 +103,8 @@ class TransformDataset(torch.utils.data.Dataset): ...@@ -103,6 +103,8 @@ class TransformDataset(torch.utils.data.Dataset):
class LibriSpeechDataModule(LightningDataModule): class LibriSpeechDataModule(LightningDataModule):
librispeech_cls = torchaudio.datasets.LIBRISPEECH
def __init__( def __init__(
self, self,
*, *,
...@@ -116,6 +118,7 @@ class LibriSpeechDataModule(LightningDataModule): ...@@ -116,6 +118,7 @@ class LibriSpeechDataModule(LightningDataModule):
train_shuffle=True, train_shuffle=True,
num_workers=10, num_workers=10,
): ):
super().__init__()
self.librispeech_path = librispeech_path self.librispeech_path = librispeech_path
self.train_dataset_lengths = None self.train_dataset_lengths = None
self.val_dataset_lengths = None self.val_dataset_lengths = None
...@@ -130,9 +133,9 @@ class LibriSpeechDataModule(LightningDataModule): ...@@ -130,9 +133,9 @@ class LibriSpeechDataModule(LightningDataModule):
def train_dataloader(self): def train_dataloader(self):
datasets = [ datasets = [
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"), self.librispeech_cls(self.librispeech_path, url="train-clean-360"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"), self.librispeech_cls(self.librispeech_path, url="train-clean-100"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"), self.librispeech_cls(self.librispeech_path, url="train-other-500"),
] ]
if not self.train_dataset_lengths: if not self.train_dataset_lengths:
...@@ -161,8 +164,8 @@ class LibriSpeechDataModule(LightningDataModule): ...@@ -161,8 +164,8 @@ class LibriSpeechDataModule(LightningDataModule):
def val_dataloader(self): def val_dataloader(self):
datasets = [ datasets = [
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"), self.librispeech_cls(self.librispeech_path, url="dev-clean"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"), self.librispeech_cls(self.librispeech_path, url="dev-other"),
] ]
if not self.val_dataset_lengths: if not self.val_dataset_lengths:
...@@ -185,7 +188,7 @@ class LibriSpeechDataModule(LightningDataModule): ...@@ -185,7 +188,7 @@ class LibriSpeechDataModule(LightningDataModule):
return dataloader return dataloader
def test_dataloader(self): def test_dataloader(self):
dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean") dataset = self.librispeech_cls(self.librispeech_path, url="test-clean")
dataset = TransformDataset(dataset, self.test_transform) dataset = TransformDataset(dataset, self.test_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None) dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
return dataloader return dataloader
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment