Commit 77a2baa8 authored by Omkar Vichare's avatar Omkar Vichare Committed by Facebook GitHub Bot
Browse files

Replace assert statements with raise in transforms (#2599)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2599

Bootcamp task T127107566.
Replacing assert statements  with if ... then raise so can be run in optimized mode

Reviewed By: mthrok

Differential Revision: D38370108

fbshipit-source-id: 74eaf5b72c511b62ddbb8e0e3b0ed638ad49e4f2
parent 6ecc11c2
...@@ -172,11 +172,14 @@ class MVDR(torch.nn.Module): ...@@ -172,11 +172,14 @@ class MVDR(torch.nn.Module):
online: bool = False, online: bool = False,
): ):
super().__init__() super().__init__()
assert solution in [ if solution not in [
"ref_channel", "ref_channel",
"stv_evd", "stv_evd",
"stv_power", "stv_power",
], "Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]." ]:
raise ValueError(
"`solution` must be one of ['ref_channel', 'stv_evd', 'stv_power']. Given {}".format(solution)
)
self.ref_channel = ref_channel self.ref_channel = ref_channel
self.solution = solution self.solution = solution
self.multi_mask = multi_mask self.multi_mask = multi_mask
......
...@@ -256,8 +256,8 @@ class GriffinLim(torch.nn.Module): ...@@ -256,8 +256,8 @@ class GriffinLim(torch.nn.Module):
) -> None: ) -> None:
super(GriffinLim, self).__init__() super(GriffinLim, self).__init__()
assert momentum < 1, "momentum={} > 1 can be unstable".format(momentum) if not (0 <= momentum < 1):
assert momentum >= 0, "momentum={} < 0".format(momentum) raise ValueError("momentum must be in the range [0, 1). Found: {}".format(momentum))
self.n_fft = n_fft self.n_fft = n_fft
self.n_iter = n_iter self.n_iter = n_iter
...@@ -379,7 +379,9 @@ class MelScale(torch.nn.Module): ...@@ -379,7 +379,9 @@ class MelScale(torch.nn.Module):
self.norm = norm self.norm = norm
self.mel_scale = mel_scale self.mel_scale = mel_scale
assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(f_min, self.f_max) if f_min > self.f_max:
raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale)
self.register_buffer("fb", fb) self.register_buffer("fb", fb)
...@@ -456,7 +458,8 @@ class InverseMelScale(torch.nn.Module): ...@@ -456,7 +458,8 @@ class InverseMelScale(torch.nn.Module):
self.tolerance_change = tolerance_change self.tolerance_change = tolerance_change
self.sgdargs = sgdargs or {"lr": 0.1, "momentum": 0.9} self.sgdargs = sgdargs or {"lr": 0.1, "momentum": 0.9}
assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(f_min, self.f_max) if f_min > self.f_max:
raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale) fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale)
self.register_buffer("fb", fb) self.register_buffer("fb", fb)
...@@ -476,7 +479,8 @@ class InverseMelScale(torch.nn.Module): ...@@ -476,7 +479,8 @@ class InverseMelScale(torch.nn.Module):
n_mels, time = shape[-2], shape[-1] n_mels, time = shape[-2], shape[-1]
freq, _ = self.fb.size() # (freq, n_mels) freq, _ = self.fb.size() # (freq, n_mels)
melspec = melspec.transpose(-1, -2) melspec = melspec.transpose(-1, -2)
assert self.n_mels == n_mels if self.n_mels != n_mels:
raise ValueError("Expected an input with {} mel bins. Found: {}".format(self.n_mels, n_mels))
specgram = torch.rand( specgram = torch.rand(
melspec.size()[0], time, freq, requires_grad=True, dtype=melspec.dtype, device=melspec.device melspec.size()[0], time, freq, requires_grad=True, dtype=melspec.dtype, device=melspec.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment