Unverified commit 0cf4b8a9, authored by Kshiteej K and committed by GitHub
Browse files

Add pathlib.Path support to `commonvoice` (#1027)

parent f1142e65
import os import os
import csv import csv
import random import random
from pathlib import Path
from torchaudio.datasets import commonvoice from torchaudio.datasets import commonvoice
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
...@@ -59,8 +60,7 @@ class TestCommonVoice(TempDirMixin, TorchaudioTestCase): ...@@ -59,8 +60,7 @@ class TestCommonVoice(TempDirMixin, TorchaudioTestCase):
# Append data entry # Append data entry
cls.data.append((normalize_wav(data), cls.sample_rate, dict(zip(cls._headers, content)))) cls.data.append((normalize_wav(data), cls.sample_rate, dict(zip(cls._headers, content))))
def test_commonvoice(self): def _test_commonvoice(self, dataset):
dataset = commonvoice.COMMONVOICE(self.root_dir)
n_ite = 0 n_ite = 0
for i, (waveform, sample_rate, dictionary) in enumerate(dataset): for i, (waveform, sample_rate, dictionary) in enumerate(dataset):
expected_dictionary = self.data[i][2] expected_dictionary = self.data[i][2]
...@@ -70,3 +70,11 @@ class TestCommonVoice(TempDirMixin, TorchaudioTestCase): ...@@ -70,3 +70,11 @@ class TestCommonVoice(TempDirMixin, TorchaudioTestCase):
assert dictionary == expected_dictionary assert dictionary == expected_dictionary
n_ite += 1 n_ite += 1
assert n_ite == len(self.data) assert n_ite == len(self.data)
def test_commonvoice_str(self):
dataset = commonvoice.COMMONVOICE(self.root_dir)
self._test_commonvoice(dataset)
def test_commonvoice_path(self):
dataset = commonvoice.COMMONVOICE(Path(self.root_dir))
self._test_commonvoice(dataset)
import os import os
from typing import List, Dict, Tuple from pathlib import Path
from typing import List, Dict, Tuple, Union
import torchaudio import torchaudio
from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader
...@@ -103,7 +104,7 @@ class COMMONVOICE(Dataset): ...@@ -103,7 +104,7 @@ class COMMONVOICE(Dataset):
"""Create a Dataset for CommonVoice. """Create a Dataset for CommonVoice.
Args: Args:
root (str): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
tsv (str, optional): The name of the tsv file used to construct the metadata. tsv (str, optional): The name of the tsv file used to construct the metadata.
(default: ``"train.tsv"``) (default: ``"train.tsv"``)
url (str, optional): The URL to download the dataset from, or the language of url (str, optional): The URL to download the dataset from, or the language of
...@@ -129,7 +130,7 @@ class COMMONVOICE(Dataset): ...@@ -129,7 +130,7 @@ class COMMONVOICE(Dataset):
_folder_audio = "clips" _folder_audio = "clips"
def __init__(self, def __init__(self,
root: str, root: Union[str, Path],
tsv: str = TSV, tsv: str = TSV,
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
...@@ -186,6 +187,9 @@ class COMMONVOICE(Dataset): ...@@ -186,6 +187,9 @@ class COMMONVOICE(Dataset):
base_url = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com" base_url = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com"
url = os.path.join(base_url, version, language + ext_archive) url = os.path.join(base_url, version, language + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(root, basename) archive = os.path.join(root, basename)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment