Unverified commit 0cf4b8a9, authored by Kshiteej K, committed by GitHub
Browse files

Add pathlib.Path support to `commonvoice` (#1027)

parent f1142e65
import os
import csv
import random
from pathlib import Path
from torchaudio.datasets import commonvoice
from torchaudio_unittest.common_utils import (
......@@ -59,8 +60,7 @@ class TestCommonVoice(TempDirMixin, TorchaudioTestCase):
# Append data entry
cls.data.append((normalize_wav(data), cls.sample_rate, dict(zip(cls._headers, content))))
def test_commonvoice(self):
dataset = commonvoice.COMMONVOICE(self.root_dir)
def _test_commonvoice(self, dataset):
n_ite = 0
for i, (waveform, sample_rate, dictionary) in enumerate(dataset):
expected_dictionary = self.data[i][2]
......@@ -70,3 +70,11 @@ class TestCommonVoice(TempDirMixin, TorchaudioTestCase):
assert dictionary == expected_dictionary
n_ite += 1
assert n_ite == len(self.data)
def test_commonvoice_str(self):
    """Build COMMONVOICE from ``self.root_dir`` as-is and run the shared checks."""
    # The root is handed over unchanged; validation is delegated to the
    # shared ``_test_commonvoice`` helper.
    self._test_commonvoice(commonvoice.COMMONVOICE(self.root_dir))
def test_commonvoice_path(self):
    """Build COMMONVOICE from a ``pathlib.Path`` root and run the shared checks."""
    # Wrap the root directory in ``Path`` so the dataset is exercised with a
    # path object rather than the raw ``self.root_dir`` value.
    root_as_path = Path(self.root_dir)
    self._test_commonvoice(commonvoice.COMMONVOICE(root_as_path))
import os
from typing import List, Dict, Tuple
from pathlib import Path
from typing import List, Dict, Tuple, Union
import torchaudio
from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader
......@@ -103,7 +104,7 @@ class COMMONVOICE(Dataset):
"""Create a Dataset for CommonVoice.
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str or Path): Path to the directory where the dataset is found or downloaded.
tsv (str, optional): The name of the tsv file used to construct the metadata.
(default: ``"train.tsv"``)
url (str, optional): The URL to download the dataset from, or the language of
......@@ -129,7 +130,7 @@ class COMMONVOICE(Dataset):
_folder_audio = "clips"
def __init__(self,
root: str,
root: Union[str, Path],
tsv: str = TSV,
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
......@@ -186,6 +187,9 @@ class COMMONVOICE(Dataset):
base_url = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com"
url = os.path.join(base_url, version, language + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url)
archive = os.path.join(root, basename)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment