Unverified commit 38d1a9b6, authored by Vincent QB and committed by GitHub
Browse files

Data points remain tuples (#330)

* close file.

* staying with datapoints as tuples until further notice.

* loading tsv as dict.
parent ffcda23f
...@@ -15,11 +15,11 @@ class TestDatasets(unittest.TestCase): ...@@ -15,11 +15,11 @@ class TestDatasets(unittest.TestCase):
path = os.path.join(test_dirpath, "assets") path = os.path.join(test_dirpath, "assets")
def test_yesno(self): def test_yesno(self):
data = YESNO(self.path, return_dict=True) data = YESNO(self.path)
data[0] data[0]
def test_vctk(self): def test_vctk(self):
data = VCTK(self.path, return_dict=True) data = VCTK(self.path)
data[0] data[0]
def test_librispeech(self): def test_librispeech(self):
......
import os import os
import torchaudio
from torch.utils.data import Dataset from torch.utils.data import Dataset
import torchaudio
from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader
# Default TSV should be one of # Default TSV should be one of
...@@ -17,6 +18,10 @@ TSV = "train.tsv" ...@@ -17,6 +18,10 @@ TSV = "train.tsv"
def load_commonvoice_item(line, header, path, folder_audio): def load_commonvoice_item(line, header, path, folder_audio):
# Each line has the following data:
# client_id, path, sentence, up_votes, down_votes, age, gender, accent
assert header[1] == "path"
fileid = line[1] fileid = line[1]
filename = os.path.join(path, folder_audio, fileid) filename = os.path.join(path, folder_audio, fileid)
...@@ -24,13 +29,17 @@ def load_commonvoice_item(line, header, path, folder_audio): ...@@ -24,13 +29,17 @@ def load_commonvoice_item(line, header, path, folder_audio):
waveform, sample_rate = torchaudio.load(filename) waveform, sample_rate = torchaudio.load(filename)
dic = dict(zip(header, line)) dic = dict(zip(header, line))
dic["waveform"] = waveform
dic["sample_rate"] = sample_rate
return dic return waveform, sample_rate, dic
class COMMONVOICE(Dataset): class COMMONVOICE(Dataset):
"""
Create a Dataset for CommonVoice. Each item is a tuple of the form:
(waveform, sample_rate, dictionary)
where dictionary is a dictionary built from the tsv file with the following keys:
client_id, path, sentence, up_votes, down_votes, age, gender, accent.
"""
_ext_txt = ".txt" _ext_txt = ".txt"
_ext_audio = ".mp3" _ext_audio = ".mp3"
......
import os import os
from torch.utils.data import Dataset
import torchaudio import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import ( from torchaudio.datasets.utils import (
download_url, download_url,
extract_archive, extract_archive,
...@@ -16,38 +15,43 @@ FOLDER_IN_ARCHIVE = "LibriSpeech" ...@@ -16,38 +15,43 @@ FOLDER_IN_ARCHIVE = "LibriSpeech"
def load_librispeech_item(fileid, path, ext_audio, ext_txt): def load_librispeech_item(fileid, path, ext_audio, ext_txt):
speaker, chapter, utterance = fileid.split("-") speaker_id, chapter_id, utterance_id = fileid.split("-")
file_text = speaker + "-" + chapter + ext_txt file_text = speaker_id + "-" + chapter_id + ext_txt
file_text = os.path.join(path, speaker, chapter, file_text) file_text = os.path.join(path, speaker_id, chapter_id, file_text)
fileid_audio = speaker + "-" + chapter + "-" + utterance fileid_audio = speaker_id + "-" + chapter_id + "-" + utterance_id
file_audio = fileid_audio + ext_audio file_audio = fileid_audio + ext_audio
file_audio = os.path.join(path, speaker, chapter, file_audio) file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)
# Load audio # Load audio
waveform, sample_rate = torchaudio.load(file_audio) waveform, sample_rate = torchaudio.load(file_audio)
# Load text # Load text
for line in open(file_text): with open(file_text) as ft:
fileid_text, content = line.strip().split(" ", 1) for line in ft:
if fileid_audio == fileid_text: fileid_text, utterance = line.strip().split(" ", 1)
break if fileid_audio == fileid_text:
else: break
# Translation not found else:
raise FileNotFoundError("Translation not found for " + fileid_audio) # Translation not found
raise FileNotFoundError("Translation not found for " + fileid_audio)
return {
"speaker_id": speaker, return (
"chapter_id": chapter, waveform,
"utterance_id": utterance, sample_rate,
"utterance": content, utterance,
"waveform": waveform, int(speaker_id),
"sample_rate": sample_rate, int(chapter_id),
} int(utterance_id),
)
class LIBRISPEECH(Dataset): class LIBRISPEECH(Dataset):
"""
Create a Dataset for LibriSpeech. Each item is a tuple of the form:
waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id
"""
_ext_txt = ".trans.txt" _ext_txt = ".trans.txt"
_ext_audio = ".flac" _ext_audio = ".flac"
......
...@@ -12,15 +12,15 @@ FOLDER_IN_ARCHIVE = "VCTK-Corpus" ...@@ -12,15 +12,15 @@ FOLDER_IN_ARCHIVE = "VCTK-Corpus"
def load_vctk_item( def load_vctk_item(
fileid, path, ext_audio, ext_txt, folder_audio, folder_txt, downsample=False fileid, path, ext_audio, ext_txt, folder_audio, folder_txt, downsample=False
): ):
speaker, utterance = fileid.split("_") speaker_id, utterance_id = fileid.split("_")
# Read text # Read text
file_txt = os.path.join(path, folder_txt, speaker, fileid + ext_txt) file_txt = os.path.join(path, folder_txt, speaker_id, fileid + ext_txt)
with open(file_txt) as file_text: with open(file_txt) as file_text:
content = file_text.readlines()[0] utterance = file_text.readlines()[0]
# Read wav # Read wav
file_audio = os.path.join(path, folder_audio, speaker, fileid + ext_audio) file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio)
if downsample: if downsample:
# Legacy # Legacy
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
...@@ -34,16 +34,14 @@ def load_vctk_item( ...@@ -34,16 +34,14 @@ def load_vctk_item(
else: else:
waveform, sample_rate = torchaudio.load(file_audio) waveform, sample_rate = torchaudio.load(file_audio)
return { return waveform, sample_rate, utterance, speaker_id, utterance_id
"speaker_id": speaker,
"utterance_id": utterance,
"utterance": content,
"waveform": waveform,
"sample_rate": sample_rate,
}
class VCTK(Dataset): class VCTK(Dataset):
"""
Create a Dataset for VCTK. Each item is a tuple of the form:
(waveform, sample_rate, utterance, speaker_id, utterance_id)
"""
_folder_txt = "txt" _folder_txt = "txt"
_folder_audio = "wav48" _folder_audio = "wav48"
...@@ -59,17 +57,8 @@ class VCTK(Dataset): ...@@ -59,17 +57,8 @@ class VCTK(Dataset):
downsample=False, downsample=False,
transform=None, transform=None,
target_transform=None, target_transform=None,
return_dict=False,
): ):
if not return_dict:
warnings.warn(
"In the next version, the item returned will be a dictionary. "
"Please use `return_dict=True` to enable this behavior now, "
"and suppress this warning.",
DeprecationWarning,
)
if downsample: if downsample:
warnings.warn( warnings.warn(
"In the next version, transforms will not be part of the dataset. " "In the next version, transforms will not be part of the dataset. "
...@@ -89,7 +78,6 @@ class VCTK(Dataset): ...@@ -89,7 +78,6 @@ class VCTK(Dataset):
self.downsample = downsample self.downsample = downsample
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.return_dict = return_dict
archive = os.path.basename(url) archive = os.path.basename(url)
archive = os.path.join(root, archive) archive = os.path.join(root, archive)
...@@ -122,23 +110,15 @@ class VCTK(Dataset): ...@@ -122,23 +110,15 @@ class VCTK(Dataset):
self._folder_txt, self._folder_txt,
) )
# Legacy # TODO Upon deprecation, uncomment line below and remove following code
waveform = item["waveform"] # return item
waveform, sample_rate, utterance, speaker_id, utterance_id = item
if self.transform is not None: if self.transform is not None:
waveform = self.transform(waveform) waveform = self.transform(waveform)
item["waveform"] = waveform
# Legacy
utterance = item["utterance"]
if self.target_transform is not None: if self.target_transform is not None:
utterance = self.target_transform(utterance) utterance = self.target_transform(utterance)
item["utterance"] = utterance return waveform, sample_rate, utterance, speaker_id, utterance_id
if self.return_dict:
return item
# Legacy
return item["waveform"], item["utterance"]
def __len__(self): def __len__(self):
return len(self._walker) return len(self._walker)
...@@ -11,16 +11,20 @@ FOLDER_IN_ARCHIVE = "waves_yesno" ...@@ -11,16 +11,20 @@ FOLDER_IN_ARCHIVE = "waves_yesno"
def load_yesno_item(fileid, path, ext_audio): def load_yesno_item(fileid, path, ext_audio):
# Read label # Read label
label = fileid.split("_") labels = [int(c) for c in fileid.split("_")]
# Read wav # Read wav
file_audio = os.path.join(path, fileid + ext_audio) file_audio = os.path.join(path, fileid + ext_audio)
waveform, sample_rate = torchaudio.load(file_audio) waveform, sample_rate = torchaudio.load(file_audio)
return {"label": label, "waveform": waveform, "sample_rate": sample_rate} return waveform, sample_rate, labels
class YESNO(Dataset): class YESNO(Dataset):
"""
Create a Dataset for YesNo. Each item is a tuple of the form:
(waveform, sample_rate, labels)
"""
_ext_audio = ".wav" _ext_audio = ".wav"
...@@ -32,17 +36,8 @@ class YESNO(Dataset): ...@@ -32,17 +36,8 @@ class YESNO(Dataset):
download=False, download=False,
transform=None, transform=None,
target_transform=None, target_transform=None,
return_dict=False,
): ):
if not return_dict:
warnings.warn(
"In the next version, the item returned will be a dictionary. "
"Please use `return_dict=True` to enable this behavior now, "
"and suppress this warning.",
DeprecationWarning,
)
if transform is not None or target_transform is not None: if transform is not None or target_transform is not None:
warnings.warn( warnings.warn(
"In the next version, transforms will not be part of the dataset. " "In the next version, transforms will not be part of the dataset. "
...@@ -53,7 +48,6 @@ class YESNO(Dataset): ...@@ -53,7 +48,6 @@ class YESNO(Dataset):
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.return_dict = return_dict
archive = os.path.basename(url) archive = os.path.basename(url)
archive = os.path.join(root, archive) archive = os.path.join(root, archive)
...@@ -79,20 +73,15 @@ class YESNO(Dataset): ...@@ -79,20 +73,15 @@ class YESNO(Dataset):
fileid = self._walker[n] fileid = self._walker[n]
item = load_yesno_item(fileid, self._path, self._ext_audio) item = load_yesno_item(fileid, self._path, self._ext_audio)
waveform = item["waveform"] # TODO Upon deprecation, uncomment line below and remove following code
# return item
waveform, sample_rate, labels = item
if self.transform is not None: if self.transform is not None:
waveform = self.transform(waveform) waveform = self.transform(waveform)
item["waveform"] = waveform
label = item["label"]
if self.target_transform is not None: if self.target_transform is not None:
label = self.target_transform(label) labels = self.target_transform(labels)
item["label"] = label return waveform, sample_rate, labels
if self.return_dict:
return item
return item["waveform"], item["label"]
def __len__(self): def __len__(self):
return len(self._walker) return len(self._walker)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment