Commit 77a2baa8 authored by Omkar Vichare's avatar Omkar Vichare Committed by Facebook GitHub Bot
Browse files

Replace assert statements with raise in transforms (#2599)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2599

Bootcamp task T127107566.
Replacing assert statements  with if ... then raise so can be run in optimized mode

Reviewed By: mthrok

Differential Revision: D38370108

fbshipit-source-id: 74eaf5b72c511b62ddbb8e0e3b0ed638ad49e4f2
parent 6ecc11c2
......@@ -172,11 +172,14 @@ class MVDR(torch.nn.Module):
online: bool = False,
):
super().__init__()
assert solution in [
if solution not in [
"ref_channel",
"stv_evd",
"stv_power",
], "Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]."
]:
raise ValueError(
"`solution` must be one of ['ref_channel', 'stv_evd', 'stv_power']. Given {}".format(solution)
)
self.ref_channel = ref_channel
self.solution = solution
self.multi_mask = multi_mask
......
......@@ -256,8 +256,8 @@ class GriffinLim(torch.nn.Module):
) -> None:
super(GriffinLim, self).__init__()
assert momentum < 1, "momentum={} > 1 can be unstable".format(momentum)
assert momentum >= 0, "momentum={} < 0".format(momentum)
if not (0 <= momentum < 1):
raise ValueError("momentum must be in the range [0, 1). Found: {}".format(momentum))
self.n_fft = n_fft
self.n_iter = n_iter
......@@ -379,7 +379,9 @@ class MelScale(torch.nn.Module):
self.norm = norm
self.mel_scale = mel_scale
assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(f_min, self.f_max)
if f_min > self.f_max:
raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale)
self.register_buffer("fb", fb)
......@@ -456,7 +458,8 @@ class InverseMelScale(torch.nn.Module):
self.tolerance_change = tolerance_change
self.sgdargs = sgdargs or {"lr": 0.1, "momentum": 0.9}
assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(f_min, self.f_max)
if f_min > self.f_max:
raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale)
self.register_buffer("fb", fb)
......@@ -476,7 +479,8 @@ class InverseMelScale(torch.nn.Module):
n_mels, time = shape[-2], shape[-1]
freq, _ = self.fb.size() # (freq, n_mels)
melspec = melspec.transpose(-1, -2)
assert self.n_mels == n_mels
if self.n_mels != n_mels:
raise ValueError("Expected an input with {} mel bins. Found: {}".format(self.n_mels, n_mels))
specgram = torch.rand(
melspec.size()[0], time, freq, requires_grad=True, dtype=melspec.dtype, device=melspec.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment