Commit 34ef7e9c authored by Son Dinh, committed by Facebook GitHub Bot

Replace assert with raise in prototypes.models (#2578)

Summary:
This commit replaces the use of `assert` with the `if ... raise` idiom,
so that the checks still run when Python executes in optimized mode (`-O`),
which strips `assert` statements.
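
The difference is easy to reproduce with a minimal, self-contained sketch (not part of the patch). Run it once as `python demo.py` and again as `python -O demo.py`: the `assert` check is compiled away under `-O`, while the explicit `raise` fires in both modes.

    def check_with_assert(x: int) -> None:
        # Compiled away under `python -O`: invalid input passes silently.
        assert x % 2 == 0, "x must be even"

    def check_with_raise(x: int) -> None:
        # Always executed, regardless of optimization flags.
        if x % 2 != 0:
            raise ValueError("x must be even")

    for check in (check_with_assert, check_with_raise):
        try:
            check(3)
            print(check.__name__, "let an odd value through")
        except (AssertionError, ValueError) as err:
            print(check.__name__, "rejected the value:", err)

Under `-O`, only `check_with_raise` rejects the odd value.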

Pull Request resolved: https://github.com/pytorch/audio/pull/2578

Reviewed By: mthrok

Differential Revision: D38158122

fbshipit-source-id: da561145a6e021238e9e9df10ab8d2d3a751fb69
parent 0f4e1e8c
@@ -63,7 +63,8 @@ class _ConvolutionModule(torch.nn.Module):
     def _split_right_context(self, utterance: torch.Tensor, right_context: torch.Tensor) -> torch.Tensor:
         T, B, D = right_context.size()
-        assert T % self.right_context_length == 0
+        if T % self.right_context_length != 0:
+            raise ValueError("Tensor length should be divisible by its right context length")
         num_segments = T // self.right_context_length
         # (num_segments, right context length, B, D)
         right_context_segments = right_context.reshape(num_segments, self.right_context_length, B, D)
...
...@@ -162,7 +162,8 @@ class _HEncLayer(torch.nn.Module): ...@@ -162,7 +162,8 @@ class _HEncLayer(torch.nn.Module):
if self.empty: if self.empty:
return y return y
if inject is not None: if inject is not None:
assert inject.shape[-1] == y.shape[-1], "injection shapes do not align" if inject.shape[-1] != y.shape[-1]:
raise ValueError("Injection shapes do not align")
if inject.dim() == 3 and y.dim() == 4: if inject.dim() == 3 and y.dim() == 4:
inject = inject[:, :, None] inject = inject[:, :, None]
y = y + inject y = y + inject
@@ -220,7 +221,8 @@ class _HDecLayer(torch.nn.Module):
         if norm_type == "group_norm":
             norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
         if pad:
-            assert (kernel_size - stride) % 2 == 0, "kernel size and stride do not align"
+            if (kernel_size - stride) % 2 != 0:
+                raise ValueError("Kernel size and stride do not align")
             pad = (kernel_size - stride) // 2
         else:
             pad = 0
@@ -279,14 +281,17 @@ class _HDecLayer(torch.nn.Module):
             y = F.glu(self.norm1(self.rewrite(x)), dim=1)
         else:
             y = x
-            assert skip is None, "skip must be none when empty is true."
+            if skip is not None:
+                raise ValueError("Skip must be none when empty is true.")
         z = self.norm2(self.conv_tr(y))
         if self.freq:
             if self.pad:
                 z = z[..., self.pad : -self.pad, :]
         else:
             z = z[..., self.pad : self.pad + length]
-            assert z.shape[-1] == length, "Last index of z must be equal to length"
+            if z.shape[-1] != length:
+                raise ValueError("Last index of z must be equal to length")
         if not self.last:
             z = F.gelu(z)
@@ -388,7 +393,8 @@ class HDemucs(torch.nn.Module):
             stri = stride
             ker = kernel_size
             if not freq:
-                assert freqs == 1, "when freq is false, freqs must be 1"
+                if freqs != 1:
+                    raise ValueError("When freq is false, freqs must be 1.")
                 ker = time_stride * 2
                 stri = time_stride
@@ -470,13 +476,15 @@ class HDemucs(torch.nn.Module):
         # which is not supported by torch.stft.
         # Having all convolution operations follow this convention allow to easily
         # align the time and frequency branches later on.
-        assert hl == nfft // 4, "hop length must be nfft // 4"
+        if hl != nfft // 4:
+            raise ValueError("Hop length must be nfft // 4")
         le = int(math.ceil(x.shape[-1] / hl))
         pad = hl // 2 * 3
         x = F.pad(x, [pad, pad + le * hl - x.shape[-1]], mode="reflect")
         z = _spectro(x, nfft, hl)[..., :-1, :]
-        assert z.shape[-1] == le + 4, "spectrogram's last dimension must be 4 + input size divided by stride"
+        if z.shape[-1] != le + 4:
+            raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride")
         z = z[..., 2 : 2 + le]
         return z
@@ -590,16 +598,20 @@ class HDemucs(torch.nn.Module):
             tdec = self.time_decoder[idx - offset]
             length_t = lengths_t.pop(-1)
             if tdec.empty:
-                assert pre.shape[2] == 1, "tdec empty is true, pre shape does not match " + pre.shape
+                if pre.shape[2] != 1:
+                    raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}")
                 pre = pre[:, :, 0]
                 xt, _ = tdec(pre, None, length_t)
             else:
                 skip = saved_t.pop(-1)
                 xt, _ = tdec(xt, skip, length_t)
 
-        assert len(saved) == 0, "saved is not empty"
-        assert len(lengths_t) == 0, "lengths_t is not empty"
-        assert len(saved_t) == 0, "saved_t is not empty"
+        if len(saved) != 0:
+            raise AssertionError("saved is not empty")
+        if len(lengths_t) != 0:
+            raise AssertionError("lengths_t is not empty")
+        if len(saved_t) != 0:
+            raise AssertionError("saved_t is not empty")
 
         S = len(self.sources)
         x = x.view(B, S, -1, Fq, T)
@@ -650,7 +662,8 @@ class _DConv(torch.nn.Module):
     ):
         super().__init__()
-        assert kernel_size % 2 == 1, "kernel size should not be divisible by 2"
+        if kernel_size % 2 == 0:
+            raise ValueError("Kernel size should not be divisible by 2")
         self.channels = channels
         self.compress = compress
         self.depth = abs(depth)
@@ -781,7 +794,8 @@ class _LocalState(nn.Module):
             ndecay (int, optional): (default: 4)
         """
         super(_LocalState, self).__init__()
-        assert channels % heads == 0, "Channels must be divisible by heads."
+        if channels % heads != 0:
+            raise ValueError("Channels must be divisible by heads.")
         self.heads = heads
         self.ndecay = ndecay
         self.content = nn.Conv1d(channels, channels, 1)
@@ -792,7 +806,8 @@ class _LocalState(nn.Module):
         if ndecay:
             # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
             self.query_decay.weight.data *= 0.01
-            assert self.query_decay.bias is not None, "bias must not be None"
+            if self.query_decay.bias is None:
+                raise ValueError("bias must not be None.")
             self.query_decay.bias.data[:] = -2
         self.proj = nn.Conv1d(channels + heads * 0, channels, 1)
@@ -874,7 +889,8 @@ def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
     tgt_length = (n_frames - 1) * stride + kernel_size
     a = F.pad(input=a, pad=[0, tgt_length - length])
     strides = [a.stride(dim) for dim in range(a.dim())]
-    assert strides[-1] == 1, "data should be contiguous"
+    if strides[-1] != 1:
+        raise ValueError("Data should be contiguous.")
     strides = strides[:-1] + [stride, 1]
     shape.append(n_frames)
     shape.append(kernel_size)
...
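
For context, these lines set up a zero-copy stride trick that is completed with `as_strided`: each frame advances `stride` elements while samples within a frame are adjacent. A minimal sketch of such a helper; the lines outside the hunk are assumptions for illustration, not the verbatim torchaudio source:

    import math
    import torch
    import torch.nn.functional as F

    def unfold_sketch(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
        """Cut the last dimension into overlapping frames without copying data."""
        shape = list(a.shape[:-1])
        length = a.shape[-1]
        n_frames = math.ceil(length / stride)
        tgt_length = (n_frames - 1) * stride + kernel_size
        a = F.pad(input=a, pad=[0, tgt_length - length])  # zero-pad to fit full frames
        strides = [a.stride(dim) for dim in range(a.dim())]
        if strides[-1] != 1:
            raise ValueError("Data should be contiguous.")
        # New trailing dims: advance `stride` elements per frame,
        # one element per sample inside a frame.
        strides = strides[:-1] + [stride, 1]
        shape.append(n_frames)
        shape.append(kernel_size)
        return a.as_strided(shape, strides)

    frames = unfold_sketch(torch.arange(10.0), kernel_size=4, stride=2)
    print(frames.shape)  # torch.Size([5, 4])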