Commit ceee6912, authored by Caroline Chen, committed by Facebook GitHub Bot
Browse files

Update QUESST14 getitem (#2435)

Summary:
Update QUESST14 `__getitem__` to include a docstring and to additionally return the sample rate.

Pull Request resolved: https://github.com/pytorch/audio/pull/2435

Reviewed By: nateanl

Differential Revision: D36864254

Pulled By: carolineechen

fbshipit-source-id: 9e68bbc5de27ad2f32f6b298414103c4f6784801
parent d2ecba98
...@@ -46,7 +46,7 @@ def _save_sample(dataset_dir, folder, language, index, sample_rate, seed): ...@@ -46,7 +46,7 @@ def _save_sample(dataset_dir, folder, language, index, sample_rate, seed):
) )
save_wav(file_path, data, sample_rate) save_wav(file_path, data, sample_rate)
sample = (data, Path(file_path).with_suffix("").name) sample = (data, sample_rate, Path(file_path).with_suffix("").name)
# add audio files and language data to language key files # add audio files and language data to language key files
scoring_path = os.path.join(dataset_dir, "scoring") scoring_path = os.path.join(dataset_dir, "scoring")
...@@ -133,9 +133,10 @@ class TestQuesst14(TempDirMixin, TorchaudioTestCase): ...@@ -133,9 +133,10 @@ class TestQuesst14(TempDirMixin, TorchaudioTestCase):
def _testQuesst14(self, dataset, data_samples): def _testQuesst14(self, dataset, data_samples):
num_samples = 0 num_samples = 0
for i, (data, name) in enumerate(dataset): for i, (data, sample_rate, name) in enumerate(dataset):
self.assertEqual(data, data_samples[i][0]) self.assertEqual(data, data_samples[i][0])
assert name == data_samples[i][1] assert sample_rate == data_samples[i][1]
assert name == data_samples[i][2]
num_samples += 1 num_samples += 1
assert num_samples == len(data_samples) assert num_samples == len(data_samples)
......
...@@ -71,10 +71,18 @@ class QUESST14(Dataset): ...@@ -71,10 +71,18 @@ class QUESST14(Dataset):
def _load_sample(self, n: int) -> Tuple[torch.Tensor, str]: def _load_sample(self, n: int) -> Tuple[torch.Tensor, str]:
audio_path = self.data[n] audio_path = self.data[n]
wav, _ = torchaudio.load(audio_path) wav, sample_rate = torchaudio.load(audio_path)
return wav, audio_path.with_suffix("").name return wav, sample_rate, audio_path.with_suffix("").name
def __getitem__(self, n: int) -> Tuple[torch.Tensor, str]: def __getitem__(self, n: int) -> Tuple[torch.Tensor, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
(Tensor, int, str): ``(waveform, sample_rate, file_name)``
"""
return self._load_sample(n) return self._load_sample(n)
def __len__(self) -> int: def __len__(self) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment