Unverified Commit 37b4e136 authored by Bhargav Kathivarapu's avatar Bhargav Kathivarapu Committed by GitHub
Browse files

Add pathlib support for TEDLIUM (#1045)

parent f3b9208f
import os import os
import platform import platform
import unittest import unittest
from pathlib import Path
from torchaudio.datasets import tedlium from torchaudio.datasets import tedlium
...@@ -93,9 +94,7 @@ class Tedlium(TempDirMixin): ...@@ -93,9 +94,7 @@ class Tedlium(TempDirMixin):
cls.samples[release].append(sample) cls.samples[release].append(sample)
seed += 1 seed += 1
def test_tedlium_release1(self): def _test_tedlium(self, dataset, release):
release = "release1"
dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0 num_samples = 0
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset): for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8) self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
...@@ -113,45 +112,25 @@ class Tedlium(TempDirMixin): ...@@ -113,45 +112,25 @@ class Tedlium(TempDirMixin):
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()] phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME assert phoenemes == PHONEME
def test_tedlium_release2(self): def test_tedlium_release1_str(self):
release = "release2" release = "release1"
dataset = tedlium.TEDLIUM(self.root_dir, release=release) dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0 self._test_tedlium(dataset, release)
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[release][i][1]
assert transcript == self.samples[release][i][2]
assert talk_id == self.samples[release][i][3]
assert speaker_id == self.samples[release][i][4]
assert identifier == self.samples[release][i][5]
num_samples += 1
assert num_samples == len(self.samples[release]) def test_tedlium_release1_path(self):
release = "release1"
dataset = tedlium.TEDLIUM(Path(self.root_dir), release=release)
self._test_tedlium(dataset, release)
dataset._dict_path = os.path.join(dataset._path, f"{release}.dic") def test_tedlium_release2(self):
phoneme_dict = dataset.phoneme_dict release = "release2"
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()] dataset = tedlium.TEDLIUM(self.root_dir, release=release)
assert phoenemes == PHONEME self._test_tedlium(dataset, release)
def test_tedlium_release3(self): def test_tedlium_release3(self):
release = "release3" release = "release3"
dataset = tedlium.TEDLIUM(self.root_dir, release=release) dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0 self._test_tedlium(dataset, release)
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[release][i][1]
assert transcript == self.samples[release][i][2]
assert talk_id == self.samples[release][i][3]
assert speaker_id == self.samples[release][i][4]
assert identifier == self.samples[release][i][5]
num_samples += 1
assert num_samples == len(self.samples[release])
dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME
class TestTedliumSoundfile(Tedlium, TorchaudioTestCase): class TestTedliumSoundfile(Tedlium, TorchaudioTestCase):
......
import os import os
from typing import Tuple from typing import Tuple, Union
from pathlib import Path
import torchaudio import torchaudio
from torch import Tensor from torch import Tensor
...@@ -46,7 +47,7 @@ class TEDLIUM(Dataset): ...@@ -46,7 +47,7 @@ class TEDLIUM(Dataset):
Create a Dataset for Tedlium. It supports releases 1,2 and 3. Create a Dataset for Tedlium. It supports releases 1,2 and 3.
Args: Args:
root (str): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
release (str, optional): Release version. release (str, optional): Release version.
Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``. Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``.
(default: ``"release1"``). (default: ``"release1"``).
...@@ -56,7 +57,12 @@ class TEDLIUM(Dataset): ...@@ -56,7 +57,12 @@ class TEDLIUM(Dataset):
Whether to download the dataset if it is not found at root path. (default: ``False``). Whether to download the dataset if it is not found at root path. (default: ``False``).
""" """
def __init__( def __init__(
self, root: str, release: str = "release1", subset: str = None, download: bool = False, audio_ext=".sph" self,
root: Union[str, Path],
release: str = "release1",
subset: str = None,
download: bool = False,
audio_ext=".sph"
) -> None: ) -> None:
self._ext_audio = audio_ext self._ext_audio = audio_ext
if release in _RELEASE_CONFIGS.keys(): if release in _RELEASE_CONFIGS.keys():
...@@ -78,6 +84,9 @@ class TEDLIUM(Dataset): ...@@ -78,6 +84,9 @@ class TEDLIUM(Dataset):
) )
) )
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(root, basename) archive = os.path.join(root, basename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment