Unverified Commit 3695a0ef authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Datasets inline typing (#511)



* add CommonDataset Inline typing

* inline Typing librispeech

* add inline typing ljspeech

* add inline typing speechcommands

* add inline typing to vctk

* add inline typing yesno

* apply type to __getitem__
Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent 04e68471
import os
from torch.utils.data import Dataset
from typing import List, Dict, Tuple
import torchaudio
from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader
from torch import Tensor
from torch.utils.data import Dataset
# Default TSV should be one of
# dev.tsv
......@@ -19,7 +20,10 @@ VERSION = "cv-corpus-4-2019-12-10"
TSV = "train.tsv"
def load_commonvoice_item(line, header, path, folder_audio):
def load_commonvoice_item(line: List[str],
header: List[str],
path: str,
folder_audio: str) -> Tuple[Tensor, int, Dict[str, str]]:
# Each line as the following data:
# client_id, path, sentence, up_votes, down_votes, age, gender, accent
......@@ -47,12 +51,13 @@ class COMMONVOICE(Dataset):
_ext_audio = ".mp3"
_folder_audio = "clips"
def __init__(self, root,
tsv=TSV,
url=URL,
folder_in_archive=FOLDER_IN_ARCHIVE,
version=VERSION,
download=False):
def __init__(self,
root: str,
tsv: str = TSV,
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
version: str = VERSION,
download: bool = False) -> None:
languages = {
"tatar": "tt",
......@@ -125,9 +130,9 @@ class COMMONVOICE(Dataset):
self._header = next(walker)
self._walker = list(walker)
def __getitem__(self, n):
def __getitem__(self, n: int) -> Tuple[Tensor, int, Dict[str, str]]:
line = self._walker[n]
return load_commonvoice_item(line, self._header, self._path, self._folder_audio)
def __len__(self):
def __len__(self) -> int:
return len(self._walker)
import os
from typing import Tuple
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
unicode_csv_reader,
walk_files,
)
......@@ -13,8 +14,10 @@ URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
def load_librispeech_item(fileid, path, ext_audio, ext_txt):
def load_librispeech_item(fileid: str,
path: str,
ext_audio: str,
ext_txt: str) -> Tuple[Tensor, int, str, int, int, int]:
speaker_id, chapter_id, utterance_id = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + ext_txt
......@@ -56,9 +59,11 @@ class LIBRISPEECH(Dataset):
_ext_txt = ".trans.txt"
_ext_audio = ".flac"
def __init__(
self, root, url=URL, folder_in_archive=FOLDER_IN_ARCHIVE, download=False
):
def __init__(self,
root: str,
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None:
if url in [
"dev-clean",
......@@ -94,9 +99,9 @@ class LIBRISPEECH(Dataset):
)
self._walker = list(walker)
def __getitem__(self, n):
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
fileid = self._walker[n]
return load_librispeech_item(fileid, self._path, self._ext_audio, self._ext_txt)
def __len__(self):
def __len__(self) -> int:
return len(self._walker)
import os
import csv
from typing import List, Tuple
import torchaudio
from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader
from torch import Tensor
from torch.utils.data import Dataset
URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
FOLDER_IN_ARCHIVE = "wavs"
def load_ljspeech_item(line, path, ext_audio):
def load_ljspeech_item(line: List[str], path: str, ext_audio: str) -> Tuple[Tensor, int, str, str]:
assert len(line) == 3
fileid, transcript, normalized_transcript = line
fileid_audio = fileid + ext_audio
......@@ -35,9 +37,11 @@ class LJSPEECH(Dataset):
_ext_audio = ".wav"
_ext_archive = '.tar.bz2'
def __init__(
self, root, url=URL, folder_in_archive=FOLDER_IN_ARCHIVE, download=False
):
def __init__(self,
root: str,
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None:
basename = os.path.basename(url)
archive = os.path.join(root, basename)
......@@ -58,9 +62,9 @@ class LJSPEECH(Dataset):
walker = unicode_csv_reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
self._walker = list(walker)
def __getitem__(self, n):
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
line = self._walker[n]
return load_ljspeech_item(line, self._path, self._ext_audio)
def __len__(self):
def __len__(self) -> int:
return len(self._walker)
import os
from typing import Tuple
import torchaudio
from torch.utils.data import Dataset
from torch import Tensor
from torchaudio.datasets.utils import (
download_url,
extract_archive,
......@@ -14,7 +16,7 @@ HASH_DIVIDER = "_nohash_"
EXCEPT_FOLDER = "_background_noise_"
def load_speechcommands_item(filepath, path):
def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]:
relpath = os.path.relpath(filepath, path)
label, filename = os.path.split(relpath)
speaker, _ = os.path.splitext(filename)
......@@ -33,13 +35,11 @@ class SPEECHCOMMANDS(Dataset):
waveform, sample_rate, label, speaker_id, utterance_number
"""
def __init__(
self,
root,
url=URL,
folder_in_archive=FOLDER_IN_ARCHIVE,
download=False
):
def __init__(self,
root: str,
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None:
if url in [
"speech_commands_v0.01",
"speech_commands_v0.02",
......@@ -67,9 +67,9 @@ class SPEECHCOMMANDS(Dataset):
walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
self._walker = list(walker)
def __getitem__(self, n):
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
fileid = self._walker[n]
return load_speechcommands_item(fileid, self._path)
def __len__(self):
def __len__(self) -> int:
return len(self._walker)
import os
import warnings
from typing import Any, Tuple
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive, walk_files
......@@ -9,9 +11,13 @@ URL = "http://homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"
FOLDER_IN_ARCHIVE = "VCTK-Corpus"
def load_vctk_item(
fileid, path, ext_audio, ext_txt, folder_audio, folder_txt, downsample=False
):
def load_vctk_item(fileid: str,
path: str,
ext_audio: str,
ext_txt: str,
folder_audio: str,
folder_txt: str,
downsample: bool = False) -> Tuple[Tensor, int, str, str, str]:
speaker_id, utterance_id = fileid.split("_")
# Read text
......@@ -50,16 +56,14 @@ class VCTK(Dataset):
_ext_audio = ".wav"
_except_folder = "p315"
def __init__(
self,
root,
url=URL,
folder_in_archive=FOLDER_IN_ARCHIVE,
download=False,
downsample=False,
transform=None,
target_transform=None,
):
def __init__(self,
root: str,
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
downsample: bool = False,
transform: Any = None,
target_transform: Any = None) -> None:
if downsample:
warnings.warn(
......@@ -100,7 +104,7 @@ class VCTK(Dataset):
walker = filter(lambda w: self._except_folder not in w, walker)
self._walker = list(walker)
def __getitem__(self, n):
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str]:
fileid = self._walker[n]
item = load_vctk_item(
fileid,
......@@ -121,5 +125,5 @@ class VCTK(Dataset):
utterance = self.target_transform(utterance)
return waveform, sample_rate, utterance, speaker_id, utterance_id
def __len__(self):
def __len__(self) -> int:
return len(self._walker)
import os
import warnings
from typing import Any, List, Tuple
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive, walk_files
......@@ -9,7 +11,7 @@ URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
FOLDER_IN_ARCHIVE = "waves_yesno"
def load_yesno_item(fileid, path, ext_audio):
def load_yesno_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, int, List[int]]:
# Read label
labels = [int(c) for c in fileid.split("_")]
......@@ -28,15 +30,13 @@ class YESNO(Dataset):
_ext_audio = ".wav"
def __init__(
self,
root,
url=URL,
folder_in_archive=FOLDER_IN_ARCHIVE,
download=False,
transform=None,
target_transform=None,
):
def __init__(self,
root: str,
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
transform: Any = None,
target_transform: Any = None) -> None:
if transform is not None or target_transform is not None:
warnings.warn(
......@@ -68,7 +68,7 @@ class YESNO(Dataset):
)
self._walker = list(walker)
def __getitem__(self, n):
def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
fileid = self._walker[n]
item = load_yesno_item(fileid, self._path, self._ext_audio)
......@@ -82,5 +82,5 @@ class YESNO(Dataset):
labels = self.target_transform(labels)
return waveform, sample_rate, labels
def __len__(self):
def __len__(self) -> int:
return len(self._walker)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment