Unverified commit 9b288109, authored by Tomás Osório and committed by GitHub
Browse files

Fix common voice dataset (#498)

* fix download

* fix reading tsv archive

* add new languages

* maintain same structure as other datasets

* update CommonVoice Tests

* fix

* change directory name

* remove extra line
parent 606ae324
...@@ -29,13 +29,11 @@ class TestDatasets(unittest.TestCase): ...@@ -29,13 +29,11 @@ class TestDatasets(unittest.TestCase):
data[0] data[0]
def test_commonvoice(self): def test_commonvoice(self):
path = os.path.join(self.path, "commonvoice") data = COMMONVOICE(self.path, url="tatar")
data = COMMONVOICE(path, "train.tsv", "tatar")
data[0] data[0]
def test_commonvoice_diskcache(self): def test_commonvoice_diskcache(self):
path = os.path.join(self.path, "commonvoice") data = COMMONVOICE(self.path, url="tatar")
data = COMMONVOICE(path, "train.tsv", "tatar")
data = diskcache_iterator(data) data = diskcache_iterator(data)
# Save # Save
data[0] data[0]
...@@ -43,8 +41,7 @@ class TestDatasets(unittest.TestCase): ...@@ -43,8 +41,7 @@ class TestDatasets(unittest.TestCase):
data[0] data[0]
def test_commonvoice_bg(self): def test_commonvoice_bg(self):
path = os.path.join(self.path, "commonvoice") data = COMMONVOICE(self.path, url="tatar")
data = COMMONVOICE(path, "train.tsv", "tatar")
data = bg_iterator(data, 5) data = bg_iterator(data, 5)
for _ in data: for _ in data:
pass pass
......
...@@ -13,7 +13,9 @@ from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv ...@@ -13,7 +13,9 @@ from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv
# train.tsv # train.tsv
# validated.tsv # validated.tsv
FOLDER_IN_ARCHIVE = "CommonVoice"
URL = "english" URL = "english"
VERSION = "cv-corpus-4-2019-12-10"
TSV = "train.tsv" TSV = "train.tsv"
...@@ -45,7 +47,12 @@ class COMMONVOICE(Dataset): ...@@ -45,7 +47,12 @@ class COMMONVOICE(Dataset):
_ext_audio = ".mp3" _ext_audio = ".mp3"
_folder_audio = "clips" _folder_audio = "clips"
def __init__(self, root, tsv=TSV, url=URL, download=False): def __init__(self, root,
tsv=TSV,
url=URL,
folder_in_archive=FOLDER_IN_ARCHIVE,
version=VERSION,
download=False):
languages = { languages = {
"tatar": "tt", "tatar": "tt",
...@@ -68,6 +75,7 @@ class COMMONVOICE(Dataset): ...@@ -68,6 +75,7 @@ class COMMONVOICE(Dataset):
"esperanto": "eo", "esperanto": "eo",
"estonian": "et", "estonian": "et",
"persian": "fa", "persian": "fa",
"portuguese": "pt",
"basque": "eu", "basque": "eu",
"spanish": "es", "spanish": "es",
"chinese": "zh-CN", "chinese": "zh-CN",
...@@ -77,29 +85,40 @@ class COMMONVOICE(Dataset): ...@@ -77,29 +85,40 @@ class COMMONVOICE(Dataset):
"kinyarwanda": "rw", "kinyarwanda": "rw",
"swedish": "sv-SE", "swedish": "sv-SE",
"russian": "ru", "russian": "ru",
"indonesian": "id",
"arabic": "ar",
"tamil": "ta",
"interlingua": "ia",
"latvian": "lv",
"japanese": "ja",
"votic": "vot",
"abkhaz": "ab",
"cantonese": "zh-HK",
"romansh sursilvan": "rm-sursilv"
} }
if url is languages: if url in languages:
ext_archive = ".tar.gz" ext_archive = ".tar.gz"
language = languages[url] language = languages[url]
base_url = ( base_url = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com"
"https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4" url = os.path.join(base_url, version, language + ext_archive)
+ ".s3.amazonaws.com/cv-corpus-3/"
)
url = base_url + language + ext_archive
archive = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(root, archive) archive = os.path.join(root, basename)
self._path = root
basename = basename.rsplit(".", 2)[0]
folder_in_archive = os.path.join(folder_in_archive, version, basename)
self._path = os.path.join(root, folder_in_archive)
if download: if download:
if not os.path.isdir(self._path): if not os.path.isdir(self._path):
if not os.path.isfile(archive): if not os.path.isfile(archive):
download_url(url, root) download_url(url, root)
extract_archive(archive) extract_archive(archive, self._path)
self._tsv = os.path.join(root, tsv) self._tsv = os.path.join(root, folder_in_archive, tsv)
with open(self._tsv, "r") as tsv: with open(self._tsv, "r") as tsv:
walker = unicode_csv_reader(tsv, delimiter="\t") walker = unicode_csv_reader(tsv, delimiter="\t")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment