Commit 9a013fdb authored by Zhaoheng Ni's avatar Zhaoheng Ni
Browse files

Add file_name to the returned item in Snips dataset (#2775)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2775

Reviewed By: carolineechen

Differential Revision: D40481144

Pulled By: nateanl

fbshipit-source-id: 5d0fb2478767704603a3ec28d74160e7892d4d0e
parent 88a8dd4d
......@@ -55,7 +55,7 @@ def _get_mocked_samples(dataset_dir: str, subset: str, seed: int):
transcript, iob, intent = f"{spk}XXX", f"{spk}YYY", f"{spk}ZZZ"
label = "BOS " + transcript + " EOS\tO " + iob + " " + intent
_save_label(label_path, wav_stem, label)
samples.append((waveform, _SAMPLE_RATE, transcript, iob, intent))
samples.append((waveform, _SAMPLE_RATE, wav_stem, transcript, iob, intent))
return samples
......@@ -100,12 +100,13 @@ class TestSnips(TempDirMixin, TorchaudioTestCase):
def _testSnips(self, dataset, data_samples):
num_samples = 0
for i, (data, sample_rate, transcript, iob, intent) in enumerate(dataset):
for i, (data, sample_rate, file_name, transcript, iob, intent) in enumerate(dataset):
self.assertEqual(data, data_samples[i][0])
assert sample_rate == data_samples[i][1]
assert transcript == data_samples[i][2]
assert iob == data_samples[i][3]
assert intent == data_samples[i][4]
assert file_name == data_samples[i][2]
assert transcript == data_samples[i][3]
assert iob == data_samples[i][4]
assert intent == data_samples[i][5]
num_samples += 1
assert num_samples == len(data_samples)
......
......@@ -112,6 +112,8 @@ class Snips(Dataset):
Path to audio
int:
Sample rate
str:
File name
str:
Transcription of audio
str:
......@@ -123,7 +125,7 @@ class Snips(Dataset):
relpath = os.path.relpath(audio_path, self._path)
file_name = audio_path.with_suffix("").name
transcript, iob, intent = self.labels[file_name]
return relpath, _SAMPLE_RATE, transcript, iob, intent
return relpath, _SAMPLE_RATE, file_name, transcript, iob, intent
def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str]:
"""Load the n-th sample from the dataset.
......@@ -138,6 +140,8 @@ class Snips(Dataset):
Waveform
int:
Sample rate
str:
File name
str:
Transcription of audio
str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment