"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "42dd5af51ec3f345018b2206a1656bb09718af67"
Commit 34ef7e9c authored by Son Dinh's avatar Son Dinh Committed by Facebook GitHub Bot
Browse files

Replace assert with raise in prototypes.models (#2578)

Summary:
This commit replaces the use of assert with `if ~ then raise` idiom,
So that they are executed even when Python is running in optimized mode.

Pull Request resolved: https://github.com/pytorch/audio/pull/2578

Reviewed By: mthrok

Differential Revision: D38158122

fbshipit-source-id: da561145a6e021238e9e9df10ab8d2d3a751fb69
parent 0f4e1e8c
......@@ -63,7 +63,8 @@ class _ConvolutionModule(torch.nn.Module):
def _split_right_context(self, utterance: torch.Tensor, right_context: torch.Tensor) -> torch.Tensor:
T, B, D = right_context.size()
assert T % self.right_context_length == 0
if T % self.right_context_length != 0:
raise ValueError("Tensor length should be divisible by its right context length")
num_segments = T // self.right_context_length
# (num_segments, right context length, B, D)
right_context_segments = right_context.reshape(num_segments, self.right_context_length, B, D)
......
......@@ -162,7 +162,8 @@ class _HEncLayer(torch.nn.Module):
if self.empty:
return y
if inject is not None:
assert inject.shape[-1] == y.shape[-1], "injection shapes do not align"
if inject.shape[-1] != y.shape[-1]:
raise ValueError("Injection shapes do not align")
if inject.dim() == 3 and y.dim() == 4:
inject = inject[:, :, None]
y = y + inject
......@@ -220,7 +221,8 @@ class _HDecLayer(torch.nn.Module):
if norm_type == "group_norm":
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
if pad:
assert (kernel_size - stride) % 2 == 0, "kernel size and stride do not align"
if (kernel_size - stride) % 2 != 0:
raise ValueError("Kernel size and stride do not align")
pad = (kernel_size - stride) // 2
else:
pad = 0
......@@ -279,14 +281,17 @@ class _HDecLayer(torch.nn.Module):
y = F.glu(self.norm1(self.rewrite(x)), dim=1)
else:
y = x
assert skip is None, "skip must be none when empty is true."
if skip is not None:
raise ValueError("Skip must be none when empty is true.")
z = self.norm2(self.conv_tr(y))
if self.freq:
if self.pad:
z = z[..., self.pad : -self.pad, :]
else:
z = z[..., self.pad : self.pad + length]
assert z.shape[-1] == length, "Last index of z must be equal to length"
if z.shape[-1] != length:
raise ValueError("Last index of z must be equal to length")
if not self.last:
z = F.gelu(z)
......@@ -388,7 +393,8 @@ class HDemucs(torch.nn.Module):
stri = stride
ker = kernel_size
if not freq:
assert freqs == 1, "when freq is false, freqs must be 1"
if freqs != 1:
raise ValueError("When freq is false, freqs must be 1.")
ker = time_stride * 2
stri = time_stride
......@@ -470,13 +476,15 @@ class HDemucs(torch.nn.Module):
# which is not supported by torch.stft.
# Having all convolution operations follow this convention allow to easily
# align the time and frequency branches later on.
assert hl == nfft // 4, "hop length must be nfft // 4"
if hl != nfft // 4:
raise ValueError("Hop length must be nfft // 4")
le = int(math.ceil(x.shape[-1] / hl))
pad = hl // 2 * 3
x = F.pad(x, [pad, pad + le * hl - x.shape[-1]], mode="reflect")
z = _spectro(x, nfft, hl)[..., :-1, :]
assert z.shape[-1] == le + 4, "spectrogram's last dimension must be 4 + input size divided by stride"
if z.shape[-1] != le + 4:
raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride")
z = z[..., 2 : 2 + le]
return z
......@@ -590,16 +598,20 @@ class HDemucs(torch.nn.Module):
tdec = self.time_decoder[idx - offset]
length_t = lengths_t.pop(-1)
if tdec.empty:
assert pre.shape[2] == 1, "tdec empty is true, pre shape does not match " + pre.shape
if pre.shape[2] != 1:
raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}")
pre = pre[:, :, 0]
xt, _ = tdec(pre, None, length_t)
else:
skip = saved_t.pop(-1)
xt, _ = tdec(xt, skip, length_t)
assert len(saved) == 0, "saved is not empty"
assert len(lengths_t) == 0, "lengths_t is not empty"
assert len(saved_t) == 0, "saved_t is not empty"
if len(saved) != 0:
raise AssertionError("saved is not empty")
if len(lengths_t) != 0:
raise AssertionError("lengths_t is not empty")
if len(saved_t) != 0:
raise AssertionError("saved_t is not empty")
S = len(self.sources)
x = x.view(B, S, -1, Fq, T)
......@@ -650,7 +662,8 @@ class _DConv(torch.nn.Module):
):
super().__init__()
assert kernel_size % 2 == 1, "kernel size should not be divisible by 2"
if kernel_size % 2 == 0:
raise ValueError("Kernel size should not be divisible by 2")
self.channels = channels
self.compress = compress
self.depth = abs(depth)
......@@ -781,7 +794,8 @@ class _LocalState(nn.Module):
ndecay (int, optional): (default: 4)
"""
super(_LocalState, self).__init__()
assert channels % heads == 0, "Channels must be divisible by heads."
if channels % heads != 0:
raise ValueError("Channels must be divisible by heads.")
self.heads = heads
self.ndecay = ndecay
self.content = nn.Conv1d(channels, channels, 1)
......@@ -792,7 +806,8 @@ class _LocalState(nn.Module):
if ndecay:
# Initialize decay close to zero (there is a sigmoid), for maximum initial window.
self.query_decay.weight.data *= 0.01
assert self.query_decay.bias is not None, "bias must not be None"
if self.query_decay.bias is None:
raise ValueError("bias must not be None.")
self.query_decay.bias.data[:] = -2
self.proj = nn.Conv1d(channels + heads * 0, channels, 1)
......@@ -874,7 +889,8 @@ def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
tgt_length = (n_frames - 1) * stride + kernel_size
a = F.pad(input=a, pad=[0, tgt_length - length])
strides = [a.stride(dim) for dim in range(a.dim())]
assert strides[-1] == 1, "data should be contiguous"
if strides[-1] != 1:
raise ValueError("Data should be contiguous.")
strides = strides[:-1] + [stride, 1]
shape.append(n_frames)
shape.append(kernel_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment