Commit 09daa438 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Fix fluent test for windows (#2510)

Summary:
fluent dataset test currently fails on windows, due to new line generation in csv writer in testing and incorrect path parsing in dataset impl.

Pull Request resolved: https://github.com/pytorch/audio/pull/2510

Reviewed By: carolineechen

Differential Revision: D37573203

Pulled By: mthrok

fbshipit-source-id: 4868bc649690c7e596b002686c6128ce735d3564
parent ef8bd7b6
...@@ -2,7 +2,6 @@ import csv ...@@ -2,7 +2,6 @@ import csv
import os import os
import random import random
import string import string
from pathlib import Path
from torchaudio.datasets import fluentcommands from torchaudio.datasets import fluentcommands
from torchaudio_unittest.common_utils import get_whitenoise, save_wav, TempDirMixin, TorchaudioTestCase from torchaudio_unittest.common_utils import get_whitenoise, save_wav, TempDirMixin, TorchaudioTestCase
...@@ -49,7 +48,7 @@ def _gen_csv(dataset_dir: str, subset: str, init_seed: int): ...@@ -49,7 +48,7 @@ def _gen_csv(dataset_dir: str, subset: str, init_seed: int):
idx += 1 idx += 1
csv_path = os.path.join(dataset_dir, "data", f"{subset}_data.csv") csv_path = os.path.join(dataset_dir, "data", f"{subset}_data.csv")
with open(csv_path, "w") as csv_file: with open(csv_path, "w", newline="") as csv_file:
file_writer = csv.writer(csv_file) file_writer = csv.writer(csv_file)
file_writer.writerows(data) file_writer.writerows(data)
...@@ -73,14 +72,15 @@ def _save_samples(dataset_dir: str, subset: str, seed: int): ...@@ -73,14 +72,15 @@ def _save_samples(dataset_dir: str, subset: str, seed: int):
n_channels=1, n_channels=1,
seed=seed, seed=seed,
) )
filename = row[path_idx] path = row[path_idx]
wav_file = os.path.join(dataset_dir, filename) filename = path.split("/")[-1]
save_wav(wav_file, wav, SAMPLE_RATE) filename = filename.split(".")[0]
path = Path(wav_file).stem
speaker_id, transcription, act, obj, loc = row[2:] speaker_id, transcription, act, obj, loc = row[2:]
sample = wav, SAMPLE_RATE, path, speaker_id, transcription, act, obj, loc wav_file = os.path.join(dataset_dir, "wavs", "speakers", speaker_id, f"{filename}.wav")
save_wav(wav_file, wav, SAMPLE_RATE)
sample = wav, SAMPLE_RATE, filename, speaker_id, transcription, act, obj, loc
samples.append(sample) samples.append(sample)
seed += 1 seed += 1
...@@ -91,6 +91,7 @@ def _save_samples(dataset_dir: str, subset: str, seed: int): ...@@ -91,6 +91,7 @@ def _save_samples(dataset_dir: str, subset: str, seed: int):
def get_mock_dataset(dataset_dir: str): def get_mock_dataset(dataset_dir: str):
data_folder = os.path.join(dataset_dir, "data") data_folder = os.path.join(dataset_dir, "data")
wav_folder = os.path.join(dataset_dir, "wavs", "speakers") wav_folder = os.path.join(dataset_dir, "wavs", "speakers")
os.makedirs(data_folder, exist_ok=True) os.makedirs(data_folder, exist_ok=True)
os.makedirs(wav_folder, exist_ok=True) os.makedirs(wav_folder, exist_ok=True)
......
...@@ -22,6 +22,9 @@ class FluentSpeechCommands(Dataset): ...@@ -22,6 +22,9 @@ class FluentSpeechCommands(Dataset):
root = os.fspath(root) root = os.fspath(root)
self._path = os.path.join(root, "fluent_speech_commands_dataset") self._path = os.path.join(root, "fluent_speech_commands_dataset")
if not os.path.isdir(self._path):
raise RuntimeError("Dataset not found.")
subset_path = os.path.join(self._path, "data", f"{subset}_data.csv") subset_path = os.path.join(self._path, "data", f"{subset}_data.csv")
with open(subset_path) as subset_csv: with open(subset_path) as subset_csv:
subset_reader = csv.reader(subset_csv) subset_reader = csv.reader(subset_csv)
...@@ -40,15 +43,17 @@ class FluentSpeechCommands(Dataset): ...@@ -40,15 +43,17 @@ class FluentSpeechCommands(Dataset):
n (int): The index of the sample to be loaded n (int): The index of the sample to be loaded
Returns: Returns:
(Tensor, int, Path, int, str, str, str, str): (Tensor, int, str, int, str, str, str, str):
``(waveform, sample_rate, path, speaker_id, transcription, action, object, location)`` ``(waveform, sample_rate, file_name, speaker_id, transcription, action, object, location)``
""" """
sample = self.data[n] sample = self.data[n]
wav_path = os.path.join(self._path, sample[self.header.index("path")])
wav, sample_rate = torchaudio.load(wav_path)
path = Path(wav_path).stem file_name = sample[self.header.index("path")].split("/")[-1]
file_name = file_name.split(".")[0]
speaker_id, transcription, action, obj, location = sample[2:] speaker_id, transcription, action, obj, location = sample[2:]
return wav, sample_rate, path, speaker_id, transcription, action, obj, location wav_path = os.path.join(self._path, "wavs", "speakers", speaker_id, f"{file_name}.wav")
wav, sample_rate = torchaudio.load(wav_path)
return wav, sample_rate, file_name, speaker_id, transcription, action, obj, location
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment