Commit e8ae0ad2 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add file_name to the returned item in Snips dataset (#2775)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2775

Reviewed By: carolineechen

Differential Revision: D40481144

Pulled By: nateanl

fbshipit-source-id: 5d0fb2478767704603a3ec28d74160e7892d4d0e
parent 0c8dfe96
...@@ -55,7 +55,7 @@ def _get_mocked_samples(dataset_dir: str, subset: str, seed: int): ...@@ -55,7 +55,7 @@ def _get_mocked_samples(dataset_dir: str, subset: str, seed: int):
transcript, iob, intent = f"{spk}XXX", f"{spk}YYY", f"{spk}ZZZ" transcript, iob, intent = f"{spk}XXX", f"{spk}YYY", f"{spk}ZZZ"
label = "BOS " + transcript + " EOS\tO " + iob + " " + intent label = "BOS " + transcript + " EOS\tO " + iob + " " + intent
_save_label(label_path, wav_stem, label) _save_label(label_path, wav_stem, label)
samples.append((waveform, _SAMPLE_RATE, transcript, iob, intent)) samples.append((waveform, _SAMPLE_RATE, wav_stem, transcript, iob, intent))
return samples return samples
...@@ -100,12 +100,13 @@ class TestSnips(TempDirMixin, TorchaudioTestCase): ...@@ -100,12 +100,13 @@ class TestSnips(TempDirMixin, TorchaudioTestCase):
def _testSnips(self, dataset, data_samples): def _testSnips(self, dataset, data_samples):
num_samples = 0 num_samples = 0
for i, (data, sample_rate, transcript, iob, intent) in enumerate(dataset): for i, (data, sample_rate, file_name, transcript, iob, intent) in enumerate(dataset):
self.assertEqual(data, data_samples[i][0]) self.assertEqual(data, data_samples[i][0])
assert sample_rate == data_samples[i][1] assert sample_rate == data_samples[i][1]
assert transcript == data_samples[i][2] assert file_name == data_samples[i][2]
assert iob == data_samples[i][3] assert transcript == data_samples[i][3]
assert intent == data_samples[i][4] assert iob == data_samples[i][4]
assert intent == data_samples[i][5]
num_samples += 1 num_samples += 1
assert num_samples == len(data_samples) assert num_samples == len(data_samples)
......
...@@ -112,6 +112,8 @@ class Snips(Dataset): ...@@ -112,6 +112,8 @@ class Snips(Dataset):
Path to audio Path to audio
int: int:
Sample rate Sample rate
str:
File name
str: str:
Transcription of audio Transcription of audio
str: str:
...@@ -123,7 +125,7 @@ class Snips(Dataset): ...@@ -123,7 +125,7 @@ class Snips(Dataset):
relpath = os.path.relpath(audio_path, self._path) relpath = os.path.relpath(audio_path, self._path)
file_name = audio_path.with_suffix("").name file_name = audio_path.with_suffix("").name
transcript, iob, intent = self.labels[file_name] transcript, iob, intent = self.labels[file_name]
return relpath, _SAMPLE_RATE, transcript, iob, intent return relpath, _SAMPLE_RATE, file_name, transcript, iob, intent
def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str]: def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
...@@ -138,6 +140,8 @@ class Snips(Dataset): ...@@ -138,6 +140,8 @@ class Snips(Dataset):
Waveform Waveform
int: int:
Sample rate Sample rate
str:
File name
str: str:
Transcription of audio Transcription of audio
str: str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment