Commit e502df01 authored by Ansh Nanda, committed by Facebook GitHub Bot

Replace assert with raise in torchaudio.models (#2590)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2590

Converted assert-based argument validation to explicit if/raise checks so that validation still runs when Python executes in optimized mode (`python -O`), which strips assert statements.
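
For illustration, here is a minimal, self-contained sketch of the pattern this commit applies (the helper name `validate_depthwise_kernel_size` is made up for this example; only the check itself mirrors `_ConvolutionModule`):

```python
def validate_depthwise_kernel_size(depthwise_kernel_size: int) -> None:
    """Hypothetical helper mirroring the check in _ConvolutionModule."""
    # Old style, removed by this commit: under ``python -O`` the assert is
    # compiled out, so an even kernel size would pass through silently.
    # assert (depthwise_kernel_size - 1) % 2 == 0, "depthwise_kernel_size must be odd."

    # New style, introduced by this commit: the check runs in every mode and
    # raises an exception type that callers can reasonably catch.
    if (depthwise_kernel_size - 1) % 2 != 0:
        raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.")


validate_depthwise_kernel_size(31)  # OK: odd kernel size
validate_depthwise_kernel_size(30)  # raises ValueError, even under ``python -O``
```

Running this under `python -O` would skip the commented-out assert entirely, while the `ValueError` is raised in both normal and optimized mode.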

Reviewed By: mthrok

Differential Revision: D38211246

fbshipit-source-id: 922b5bcafe8214980e535527dd94c3345c1ff3e2
parent c26b38b2
@@ -37,7 +37,8 @@ class _ConvolutionModule(torch.nn.Module):
         use_group_norm: bool = False,
     ) -> None:
         super().__init__()
-        assert (depthwise_kernel_size - 1) % 2 == 0, "depthwise_kernel_size must be odd to achieve 'SAME' padding."
+        if (depthwise_kernel_size - 1) % 2 != 0:
+            raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.")
         self.layer_norm = torch.nn.LayerNorm(input_dim)
         self.sequential = torch.nn.Sequential(
             torch.nn.Conv1d(
@@ -61,7 +61,8 @@ def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers
 def _gen_attention_mask_block(
     col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
 ) -> torch.Tensor:
-    assert len(col_widths) == len(col_mask), "Length of col_widths must match that of col_mask"
+    if len(col_widths) != len(col_mask):
+        raise ValueError("Length of col_widths must match that of col_mask")
     mask_block = [
         torch.ones(num_rows, col_width, device=device)
@@ -194,11 +195,12 @@ class _EmformerAttention(torch.nn.Module):
         # Compute attention.
         attention = torch.bmm(attention_probs, reshaped_value)
-        assert attention.shape == (
+        if attention.shape != (
             B * self.num_heads,
             T,
             self.input_dim // self.num_heads,
-        )
+        ):
+            raise AssertionError("Computed attention has incorrect dimensions")
         attention = attention.transpose(0, 1).contiguous().view(T, B, self.input_dim)

         # Apply output projection.
@@ -770,7 +772,8 @@ class _EmformerImpl(torch.nn.Module):
             output states; list of lists of tensors representing internal state
             generated in current invocation of ``infer``.
         """
-        assert input.size(1) == self.segment_length + self.right_context_length, (
+        if input.size(1) != self.segment_length + self.right_context_length:
+            raise ValueError(
             "Per configured segment_length and right_context_length"
             f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
             f", but got {input.size(1)}."
@@ -280,13 +280,13 @@ class RNNTBeamSearch(torch.nn.Module):
         Returns:
             List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
         """
-        assert input.dim() == 2 or (
-            input.dim() == 3 and input.shape[0] == 1
-        ), "input must be of shape (T, D) or (1, T, D)"
+        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+            raise ValueError("input must be of shape (T, D) or (1, T, D)")
         if input.dim() == 2:
             input = input.unsqueeze(0)
-        assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
+        if length.shape != () and length.shape != (1,):
+            raise ValueError("length must be of shape () or (1,)")
         if input.dim() == 0:
             input = input.unsqueeze(0)
@@ -326,13 +326,13 @@ class RNNTBeamSearch(torch.nn.Module):
             list of lists of tensors representing transcription network
             internal state generated in current invocation.
         """
-        assert input.dim() == 2 or (
-            input.dim() == 3 and input.shape[0] == 1
-        ), "input must be of shape (T, D) or (1, T, D)"
+        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+            raise ValueError("input must be of shape (T, D) or (1, T, D)")
         if input.dim() == 2:
             input = input.unsqueeze(0)
-        assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
+        if length.shape != () and length.shape != (1,):
+            raise ValueError("length must be of shape () or (1,)")
         if length.dim() == 0:
             length = length.unsqueeze(0)
@@ -83,7 +83,8 @@ def _get_conv1d_layer(
         (torch.nn.Conv1d): The corresponding Conv1D layer.
     """
     if padding is None:
-        assert kernel_size % 2 == 1
+        if kernel_size % 2 != 1:
+            raise ValueError("kernel_size must be odd")
         padding = int(dilation * (kernel_size - 1) / 2)
     conv1d = torch.nn.Conv1d(
@@ -1033,7 +1034,6 @@ class Tacotron2(nn.Module):
             lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)
-        assert lengths is not None  # For TorchScript compiler
         embedded_inputs = self.embedding(tokens).transpose(1, 2)
         encoder_outputs = self.encoder(embedded_inputs, lengths)
         mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(encoder_outputs, lengths)
@@ -512,7 +512,8 @@ def _get_feature_extractor(
         - Large:
             https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61
     """
-    assert norm_mode in ["group_norm", "layer_norm"]
+    if norm_mode not in ["group_norm", "layer_norm"]:
+        raise ValueError("Invalid norm mode")
     blocks = []
     in_channels = 1
     for i, (out_channels, kernel_size, stride) in enumerate(shapes):
@@ -155,9 +155,10 @@ class HuBERTPretrainModel(Module):
         self.wav2vec2 = wav2vec2
         self.mask_generator = mask_generator
         self.logit_generator = logit_generator
-        assert (
-            feature_grad_mult is None or 0.0 < feature_grad_mult < 1.0
-        ), f"The value of `feature_grad_mult` must be ``None`` or between (0, 1). Found {feature_grad_mult}"
+        if feature_grad_mult is not None and not 0.0 < feature_grad_mult < 1.0:
+            raise ValueError(
+                f"The value of `feature_grad_mult` must be ``None`` or between (0, 1). Found {feature_grad_mult}"
+            )
         self.feature_grad_mult = feature_grad_mult

     def forward(
@@ -204,7 +205,8 @@ class HuBERTPretrainModel(Module):
         x, attention_mask = self.wav2vec2.encoder._preprocess(x, lengths)
         x, mask = self.mask_generator(x, padding_mask)
         x = self.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
-        assert x.shape[1] == labels.shape[1], "The length of label must match that of HuBERT model output"
+        if x.shape[1] != labels.shape[1]:
+            raise ValueError("The length of label must match that of HuBERT model output")
         if padding_mask is not None:
             mask_m = torch.logical_and(~padding_mask, mask)
             mask_u = torch.logical_and(~padding_mask, ~mask_m)
@@ -277,8 +277,10 @@ class WaveRNN(nn.Module):
             Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
         """
-        assert waveform.size(1) == 1, "Require the input channel of waveform is 1"
-        assert specgram.size(1) == 1, "Require the input channel of specgram is 1"
+        if waveform.size(1) != 1:
+            raise ValueError("Require the input channel of waveform is 1")
+        if specgram.size(1) != 1:
+            raise ValueError("Require the input channel of specgram is 1")
         # remove channel dimension until the end
         waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)