Unverified commit a4d643ea, authored by Krishna Kalyan, committed by GitHub
Browse files

Refactor LJSPEECH dataset (#1143)


Co-authored-by: krishnakalyan3 <skalyan@cloudera.com>
parent b33c539c
...@@ -8,31 +8,15 @@ from torchaudio.datasets.utils import download_url, extract_archive ...@@ -8,31 +8,15 @@ from torchaudio.datasets.utils import download_url, extract_archive
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
# Download metadata for each published release of the dataset, keyed by release name.
# "url" is the archive location, "folder_in_archive" is the audio directory inside the
# extracted archive, and "checksum" is the hex digest handed to ``download_url`` as
# ``hash_value``.
_RELEASE_CONFIGS = {
    "release1": {
        "folder_in_archive": "wavs",
        "url": "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
        "checksum": "be1a30453f28eb8dd26af4101ae40cbf2c50413b1bb21936cbcdc6fae3de8aa5",
    }
}

# Pre-refactor module-level names, kept as aliases derived from the table above so that
# code which still references them keeps working and the values cannot drift apart.
URL = _RELEASE_CONFIGS["release1"]["url"]
FOLDER_IN_ARCHIVE = _RELEASE_CONFIGS["release1"]["folder_in_archive"]
_CHECKSUMS = {cfg["url"]: cfg["checksum"] for cfg in _RELEASE_CONFIGS.values()}
def load_ljspeech_item(line: List[str], path: str, ext_audio: str) -> Tuple[Tensor, int, str, str]:
    """Load the audio and transcripts for one row of the metadata file.

    Args:
        line (List[str]): One metadata row, exactly
            ``[fileid, transcript, normalized_transcript]``.
        path (str): Directory that contains the audio files.
        ext_audio (str): Extension appended to ``fileid`` to form the audio file name.

    Returns:
        tuple: ``(waveform, sample_rate, transcript, normalized_transcript)``

    Raises:
        ValueError: If ``line`` does not contain exactly three fields.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
    # and a malformed row should fail with a message that names the problem.
    if len(line) != 3:
        raise ValueError(
            "Expected 3 fields (fileid, transcript, normalized_transcript), "
            f"got {len(line)}: {line!r}"
        )
    fileid, transcript, normalized_transcript = line
    fileid_audio = os.path.join(path, fileid + ext_audio)

    # Load audio
    waveform, sample_rate = torchaudio.load(fileid_audio)
    return (
        waveform,
        sample_rate,
        transcript,
        normalized_transcript,
    )
class LJSPEECH(Dataset):
    """Create a Dataset for LJSpeech-1.1.

    Args:
        root (str or Path): Directory where the archive is stored and extracted;
            the dataset is looked up underneath it.
        url (str, optional): URL of the ``.tar.bz2`` archive to download.
            (default: the ``"release1"`` URL in ``_RELEASE_CONFIGS``)
        folder_in_archive (str, optional): Name of the audio directory inside the
            extracted archive. (default: ``"wavs"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    _ext_audio = ".wav"
    _ext_archive = '.tar.bz2'

    def __init__(self,
                 root: Union[str, Path],
                 url: str = _RELEASE_CONFIGS["release1"]["url"],
                 folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
                 download: bool = False) -> None:
        self._parse_filesystem(root, url, folder_in_archive, download)

    def _parse_filesystem(self,
                          root: Union[str, Path],
                          url: str,
                          folder_in_archive: str,
                          download: bool) -> None:
        """Resolve the dataset paths, optionally fetch the archive, and read the metadata.

        Sets ``self._path`` (audio directory), ``self._metadata_path`` and
        ``self._flist`` (the rows of ``metadata.csv``, each a list of fields).

        Args:
            root (str or Path): Base directory for the archive and extracted data.
            url (str): Archive URL; its last path component names the archive file.
            folder_in_archive (str): Audio directory name inside the extracted archive.
            download (bool): If ``True`` and the audio directory is missing, extract
                the archive, downloading it first when it is not already on disk.
        """
        root = Path(root)

        # The archive file name is the last component of the URL; the extracted
        # top-level directory is that name with the archive extension dropped.
        archive_name = os.path.basename(url)
        archive = root / archive_name
        basename = Path(archive_name.split(self._ext_archive)[0])

        self._path = root / basename / folder_in_archive
        self._metadata_path = root / basename / 'metadata.csv'

        if download and not self._path.is_dir():
            if not archive.is_file():
                # Only verify against a checksum recorded for this exact URL.
                # A caller-supplied URL has no known digest, so it is downloaded
                # unverified (``hash_value=None``) rather than failing against
                # the release1 digest.
                checksum = next(
                    (cfg["checksum"] for cfg in _RELEASE_CONFIGS.values() if cfg["url"] == url),
                    None,
                )
                download_url(url, root, hash_value=checksum)
            extract_archive(archive)

        # Open as UTF-8 rather than the locale default so that non-ASCII
        # transcripts decode the same way on every platform.
        # NOTE(review): assumes metadata.csv is UTF-8 per the dataset docs - confirm.
        with open(self._metadata_path, "r", newline='', encoding="utf-8") as metadata:
            flist = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
            self._flist = list(flist)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            tuple: ``(waveform, sample_rate, transcript, normalized_transcript)``
        """
        line = self._flist[n]
        fileid, transcript, normalized_transcript = line
        fileid_audio = self._path / (fileid + self._ext_audio)

        # Load audio
        waveform, sample_rate = torchaudio.load(fileid_audio)
        return (
            waveform,
            sample_rate,
            transcript,
            normalized_transcript,
        )

    def __len__(self) -> int:
        """Return the number of rows read from ``metadata.csv``."""
        return len(self._flist)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment