Unverified commit 37b4e136, authored by Bhargav Kathivarapu and committed by GitHub
Browse files

Add pathlib support for TEDLIUM (#1045)

parent f3b9208f
import os
import platform
import unittest
from pathlib import Path
from torchaudio.datasets import tedlium
......@@ -93,9 +94,7 @@ class Tedlium(TempDirMixin):
cls.samples[release].append(sample)
seed += 1
def test_tedlium_release1(self):
release = "release1"
dataset = tedlium.TEDLIUM(self.root_dir, release=release)
def _test_tedlium(self, dataset, release):
num_samples = 0
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
......@@ -113,45 +112,25 @@ class Tedlium(TempDirMixin):
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME
def test_tedlium_release2(self):
release = "release2"
def test_tedlium_release1_str(self):
release = "release1"
dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[release][i][1]
assert transcript == self.samples[release][i][2]
assert talk_id == self.samples[release][i][3]
assert speaker_id == self.samples[release][i][4]
assert identifier == self.samples[release][i][5]
num_samples += 1
self._test_tedlium(dataset, release)
assert num_samples == len(self.samples[release])
def test_tedlium_release1_path(self):
    """Exercise release1 with ``root`` supplied as a ``pathlib.Path``."""
    version = "release1"
    # Wrap the root directory in a Path so the non-str code path is covered.
    root_path = Path(self.root_dir)
    self._test_tedlium(tedlium.TEDLIUM(root_path, release=version), version)
dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME
def test_tedlium_release2(self):
    """Run the shared ``_test_tedlium`` checks against release2."""
    version = "release2"
    self._test_tedlium(tedlium.TEDLIUM(self.root_dir, release=version), version)
def test_tedlium_release3(self):
release = "release3"
dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[release][i][1]
assert transcript == self.samples[release][i][2]
assert talk_id == self.samples[release][i][3]
assert speaker_id == self.samples[release][i][4]
assert identifier == self.samples[release][i][5]
num_samples += 1
assert num_samples == len(self.samples[release])
dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME
self._test_tedlium(dataset, release)
class TestTedliumSoundfile(Tedlium, TorchaudioTestCase):
......
import os
from typing import Tuple
from typing import Tuple, Union
from pathlib import Path
import torchaudio
from torch import Tensor
......@@ -46,7 +47,7 @@ class TEDLIUM(Dataset):
Create a Dataset for Tedlium. It supports releases 1,2 and 3.
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str or Path): Path to the directory where the dataset is found or downloaded.
release (str, optional): Release version.
Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``.
(default: ``"release1"``).
......@@ -56,7 +57,12 @@ class TEDLIUM(Dataset):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
def __init__(
self, root: str, release: str = "release1", subset: str = None, download: bool = False, audio_ext=".sph"
self,
root: Union[str, Path],
release: str = "release1",
subset: str = None,
download: bool = False,
audio_ext=".sph"
) -> None:
self._ext_audio = audio_ext
if release in _RELEASE_CONFIGS.keys():
......@@ -78,6 +84,9 @@ class TEDLIUM(Dataset):
)
)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url)
archive = os.path.join(root, basename)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment