"examples/vscode:/vscode.git/clone" did not exist on "3efb5d8ecf7d748655e2199d120a40888ece2282"
Commit 5bf73b59 authored by Yu Shi, committed by Facebook GitHub Bot
Browse files

Fix argument validation in TorchAudio datasets (#2571)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2571

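Per T127106783, replace `assert` statements with explicit `if <condition>: raise ValueError(...)` checks so that argument validation is still enforced when Python runs in optimized mode (`-O`), where `assert` statements are stripped.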

Reviewed By: mthrok

Differential Revision: D38123481

fbshipit-source-id: 19321f7467bfd993b38bd9e44fcd01e5f5e64b87
parent 379487de
...@@ -14,7 +14,8 @@ def load_commonvoice_item( ...@@ -14,7 +14,8 @@ def load_commonvoice_item(
# Each line as the following data: # Each line as the following data:
# client_id, path, sentence, up_votes, down_votes, age, gender, accent # client_id, path, sentence, up_votes, down_votes, age, gender, accent
assert header[1] == "path" if header[1] != "path":
raise ValueError(f"expect `header[1]` to be 'path', but got {header[1]}")
fileid = line[1] fileid = line[1]
filename = os.path.join(path, folder_audio, fileid) filename = os.path.join(path, folder_audio, fileid)
if not filename.endswith(ext_audio): if not filename.endswith(ext_audio):
......
...@@ -17,7 +17,8 @@ class FluentSpeechCommands(Dataset): ...@@ -17,7 +17,8 @@ class FluentSpeechCommands(Dataset):
""" """
def __init__(self, root: Union[str, Path], subset: str = "train"): def __init__(self, root: Union[str, Path], subset: str = "train"):
assert subset in ["train", "valid", "test"], "`subset` must be one of ['train', 'valid', 'test']" if subset not in ["train", "valid", "test"]:
raise ValueError("`subset` must be one of ['train', 'valid', 'test']")
root = os.fspath(root) root = os.fspath(root)
self._path = os.path.join(root, "fluent_speech_commands_dataset") self._path = os.path.join(root, "fluent_speech_commands_dataset")
......
...@@ -1036,9 +1036,8 @@ class GTZAN(Dataset): ...@@ -1036,9 +1036,8 @@ class GTZAN(Dataset):
self.download = download self.download = download
self.subset = subset self.subset = subset
assert subset is None or subset in ["training", "validation", "testing"], ( if subset is not None and subset not in ["training", "validation", "testing"]:
"When `subset` not None, it must take a value from " + "{'training', 'validation', 'testing'}." raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")
)
archive = os.path.basename(url) archive = os.path.basename(url)
archive = os.path.join(root, archive) archive = os.path.join(root, archive)
......
...@@ -63,7 +63,8 @@ class LibriLightLimited(Dataset): ...@@ -63,7 +63,8 @@ class LibriLightLimited(Dataset):
subset: str = "10min", subset: str = "10min",
download: bool = False, download: bool = False,
) -> None: ) -> None:
assert subset in ["10min", "1h", "10h"], "`subset` must be one of ['10min', '1h', '10h']" if subset not in ["10min", "1h", "10h"]:
raise ValueError("`subset` must be one of ['10min', '1h', '10h']")
root = os.fspath(root) root = os.fspath(root)
self._path = os.path.join(root, _ARCHIVE_NAME) self._path = os.path.join(root, _ARCHIVE_NAME)
......
...@@ -62,11 +62,10 @@ class MUSDB_HQ(Dataset): ...@@ -62,11 +62,10 @@ class MUSDB_HQ(Dataset):
archive = os.path.join(root, basename) archive = os.path.join(root, basename)
basename = basename.rsplit(".", 2)[0] basename = basename.rsplit(".", 2)[0]
assert subset in ["test", "train"], "`subset` must be one of ['test', 'train']" if subset not in ["test", "train"]:
assert self.split is None or self.split in [ raise ValueError("`subset` must be one of ['test', 'train']")
"train", if self.split is not None and self.split not in ["train", "validation"]:
"validation", raise ValueError("`split` must be one of ['train', 'validation']")
], "`split` must be one of ['train', 'validation']"
base_path = os.path.join(root, basename) base_path = os.path.join(root, basename)
self._path = os.path.join(base_path, subset) self._path = os.path.join(base_path, subset)
if not os.path.isdir(self._path): if not os.path.isdir(self._path):
...@@ -89,11 +88,13 @@ class MUSDB_HQ(Dataset): ...@@ -89,11 +88,13 @@ class MUSDB_HQ(Dataset):
for source in self.sources: for source in self.sources:
track = self._get_track(name, source) track = self._get_track(name, source)
wav, sr = torchaudio.load(str(track)) wav, sr = torchaudio.load(str(track))
assert sr == _SAMPLE_RATE, f"expected sample rate {_SAMPLE_RATE}, but got {sr}" if sr != _SAMPLE_RATE:
raise ValueError(f"expected sample rate {_SAMPLE_RATE}, but got {sr}")
if num_frames is None: if num_frames is None:
num_frames = wav.shape[-1] num_frames = wav.shape[-1]
else: else:
assert wav.shape[-1] == num_frames, "num_frames do not match across sources" if wav.shape[-1] != num_frames:
raise ValueError("num_frames do not match across sources")
wavs.append(wav) wavs.append(wav)
stacked = torch.stack(wavs) stacked = torch.stack(wavs)
......
...@@ -42,9 +42,11 @@ class QUESST14(Dataset): ...@@ -42,9 +42,11 @@ class QUESST14(Dataset):
language: Optional[str] = "nnenglish", language: Optional[str] = "nnenglish",
download: bool = False, download: bool = False,
) -> None: ) -> None:
assert subset in ["docs", "dev", "eval"], "`subset` must be one of ['docs', 'dev', 'eval']" if subset not in ["docs", "dev", "eval"]:
raise ValueError("`subset` must be one of ['docs', 'dev', 'eval']")
assert language is None or language in _LANGUAGES, f"`language` must be None or one of {str(_LANGUAGES)}" if language is not None and language not in _LANGUAGES:
raise ValueError(f"`language` must be None or one of {str(_LANGUAGES)}")
# Get string representation of 'root' # Get string representation of 'root'
root = os.fspath(root) root = os.fspath(root)
......
...@@ -79,9 +79,8 @@ class SPEECHCOMMANDS(Dataset): ...@@ -79,9 +79,8 @@ class SPEECHCOMMANDS(Dataset):
subset: Optional[str] = None, subset: Optional[str] = None,
) -> None: ) -> None:
assert subset is None or subset in ["training", "validation", "testing"], ( if subset is not None and subset not in ["training", "validation", "testing"]:
"When `subset` not None, it must take a value from " + "{'training', 'validation', 'testing'}." raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")
)
if url in [ if url in [
"speech_commands_v0.01", "speech_commands_v0.01",
......
...@@ -137,7 +137,8 @@ class VoxCeleb1Identification(VoxCeleb1): ...@@ -137,7 +137,8 @@ class VoxCeleb1Identification(VoxCeleb1):
self, root: Union[str, Path], subset: str = "train", meta_url: str = _IDEN_SPLIT_URL, download: bool = False self, root: Union[str, Path], subset: str = "train", meta_url: str = _IDEN_SPLIT_URL, download: bool = False
) -> None: ) -> None:
super().__init__(root, download) super().__init__(root, download)
assert subset in ["train", "dev", "test"], "`subset` must be one of ['train', 'dev', 'test']" if subset not in ["train", "dev", "test"]:
raise ValueError("`subset` must be one of ['train', 'dev', 'test']")
# download the iden_split.txt to get the train, dev, test lists. # download the iden_split.txt to get the train, dev, test lists.
meta_list_path = os.path.join(root, os.path.basename(meta_url)) meta_list_path = os.path.join(root, os.path.basename(meta_url))
if not os.path.exists(meta_list_path): if not os.path.exists(meta_list_path):
...@@ -205,7 +206,8 @@ class VoxCeleb1Verification(VoxCeleb1): ...@@ -205,7 +206,8 @@ class VoxCeleb1Verification(VoxCeleb1):
file_id_spk2 = _get_file_id(file_path_spk2, self._ext_audio) file_id_spk2 = _get_file_id(file_path_spk2, self._ext_audio)
waveform_spk1, sample_rate = torchaudio.load(os.path.join(self._path, file_path_spk1)) waveform_spk1, sample_rate = torchaudio.load(os.path.join(self._path, file_path_spk1))
waveform_spk2, sample_rate2 = torchaudio.load(os.path.join(self._path, file_path_spk2)) waveform_spk2, sample_rate2 = torchaudio.load(os.path.join(self._path, file_path_spk2))
assert sample_rate == sample_rate2 if sample_rate != sample_rate2:
raise ValueError(f"`sample_rate` {sample_rate} is not equal to `sample_rate2` {sample_rate2}")
return (waveform_spk1, waveform_spk2, sample_rate, label, file_id_spk1, file_id_spk2) return (waveform_spk1, waveform_spk2, sample_rate, label, file_id_spk1, file_id_spk2)
def __len__(self) -> int: def __len__(self) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment