Commit 5bf73b59 authored by Yu Shi's avatar Yu Shi Committed by Facebook GitHub Bot
Browse files

Fix argument validation in TorchAudio datasets (#2571)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2571

Per T127106783, replace `assert` statement with `if _ then raise` statement to enforce the assertion even in optimized mode

Reviewed By: mthrok

Differential Revision: D38123481

fbshipit-source-id: 19321f7467bfd993b38bd9e44fcd01e5f5e64b87
parent 379487de
......@@ -14,7 +14,8 @@ def load_commonvoice_item(
# Each line as the following data:
# client_id, path, sentence, up_votes, down_votes, age, gender, accent
assert header[1] == "path"
if header[1] != "path":
raise ValueError(f"expect `header[1]` to be 'path', but got {header[1]}")
fileid = line[1]
filename = os.path.join(path, folder_audio, fileid)
if not filename.endswith(ext_audio):
......
......@@ -17,7 +17,8 @@ class FluentSpeechCommands(Dataset):
"""
def __init__(self, root: Union[str, Path], subset: str = "train"):
assert subset in ["train", "valid", "test"], "`subset` must be one of ['train', 'valid', 'test']"
if subset not in ["train", "valid", "test"]:
raise ValueError("`subset` must be one of ['train', 'valid', 'test']")
root = os.fspath(root)
self._path = os.path.join(root, "fluent_speech_commands_dataset")
......
......@@ -1036,9 +1036,8 @@ class GTZAN(Dataset):
self.download = download
self.subset = subset
assert subset is None or subset in ["training", "validation", "testing"], (
"When `subset` not None, it must take a value from " + "{'training', 'validation', 'testing'}."
)
if subset is not None and subset not in ["training", "validation", "testing"]:
raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")
archive = os.path.basename(url)
archive = os.path.join(root, archive)
......
......@@ -63,7 +63,8 @@ class LibriLightLimited(Dataset):
subset: str = "10min",
download: bool = False,
) -> None:
assert subset in ["10min", "1h", "10h"], "`subset` must be one of ['10min', '1h', '10h']"
if subset not in ["10min", "1h", "10h"]:
raise ValueError("`subset` must be one of ['10min', '1h', '10h']")
root = os.fspath(root)
self._path = os.path.join(root, _ARCHIVE_NAME)
......
......@@ -62,11 +62,10 @@ class MUSDB_HQ(Dataset):
archive = os.path.join(root, basename)
basename = basename.rsplit(".", 2)[0]
assert subset in ["test", "train"], "`subset` must be one of ['test', 'train']"
assert self.split is None or self.split in [
"train",
"validation",
], "`split` must be one of ['train', 'validation']"
if subset not in ["test", "train"]:
raise ValueError("`subset` must be one of ['test', 'train']")
if self.split is not None and self.split not in ["train", "validation"]:
raise ValueError("`split` must be one of ['train', 'validation']")
base_path = os.path.join(root, basename)
self._path = os.path.join(base_path, subset)
if not os.path.isdir(self._path):
......@@ -89,11 +88,13 @@ class MUSDB_HQ(Dataset):
for source in self.sources:
track = self._get_track(name, source)
wav, sr = torchaudio.load(str(track))
assert sr == _SAMPLE_RATE, f"expected sample rate {_SAMPLE_RATE}, but got {sr}"
if sr != _SAMPLE_RATE:
raise ValueError(f"expected sample rate {_SAMPLE_RATE}, but got {sr}")
if num_frames is None:
num_frames = wav.shape[-1]
else:
assert wav.shape[-1] == num_frames, "num_frames do not match across sources"
if wav.shape[-1] != num_frames:
raise ValueError("num_frames do not match across sources")
wavs.append(wav)
stacked = torch.stack(wavs)
......
......@@ -42,9 +42,11 @@ class QUESST14(Dataset):
language: Optional[str] = "nnenglish",
download: bool = False,
) -> None:
assert subset in ["docs", "dev", "eval"], "`subset` must be one of ['docs', 'dev', 'eval']"
if subset not in ["docs", "dev", "eval"]:
raise ValueError("`subset` must be one of ['docs', 'dev', 'eval']")
assert language is None or language in _LANGUAGES, f"`language` must be None or one of {str(_LANGUAGES)}"
if language is not None and language not in _LANGUAGES:
raise ValueError(f"`language` must be None or one of {str(_LANGUAGES)}")
# Get string representation of 'root'
root = os.fspath(root)
......
......@@ -79,9 +79,8 @@ class SPEECHCOMMANDS(Dataset):
subset: Optional[str] = None,
) -> None:
assert subset is None or subset in ["training", "validation", "testing"], (
"When `subset` not None, it must take a value from " + "{'training', 'validation', 'testing'}."
)
if subset is not None and subset not in ["training", "validation", "testing"]:
raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")
if url in [
"speech_commands_v0.01",
......
......@@ -137,7 +137,8 @@ class VoxCeleb1Identification(VoxCeleb1):
self, root: Union[str, Path], subset: str = "train", meta_url: str = _IDEN_SPLIT_URL, download: bool = False
) -> None:
super().__init__(root, download)
assert subset in ["train", "dev", "test"], "`subset` must be one of ['train', 'dev', 'test']"
if subset not in ["train", "dev", "test"]:
raise ValueError("`subset` must be one of ['train', 'dev', 'test']")
# download the iden_split.txt to get the train, dev, test lists.
meta_list_path = os.path.join(root, os.path.basename(meta_url))
if not os.path.exists(meta_list_path):
......@@ -205,7 +206,8 @@ class VoxCeleb1Verification(VoxCeleb1):
file_id_spk2 = _get_file_id(file_path_spk2, self._ext_audio)
waveform_spk1, sample_rate = torchaudio.load(os.path.join(self._path, file_path_spk1))
waveform_spk2, sample_rate2 = torchaudio.load(os.path.join(self._path, file_path_spk2))
assert sample_rate == sample_rate2
if sample_rate != sample_rate2:
raise ValueError(f"`sample_rate` {sample_rate} is not equal to `sample_rate2` {sample_rate2}")
return (waveform_spk1, waveform_spk2, sample_rate, label, file_id_spk1, file_id_spk2)
def __len__(self) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment