Commit e502df01 authored by Ansh Nanda, committed by Facebook GitHub Bot

Replace assert with raise in torchaudio.models (#2590)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2590

Converted `assert` checks used for argument validation into explicit `if`/`raise` checks, so that validation also runs in optimized mode (`python -O`), where assert statements are stripped. A short sketch of the difference follows the commit metadata below.

Reviewed By: mthrok

Differential Revision: D38211246

fbshipit-source-id: 922b5bcafe8214980e535527dd94c3345c1ff3e2
parent c26b38b2
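
Background on the motivation: when CPython runs with the `-O` flag (or with `PYTHONOPTIMIZE` set), all `assert` statements are compiled out, so argument validation written as an `assert` silently disappears, whereas an explicit `raise` always executes. Below is a minimal illustrative sketch of the difference; the helper names are hypothetical, not taken from the diff that follows.

# Illustrative sketch (hypothetical helpers, not part of torchaudio):
def validate_with_assert(kernel_size: int) -> None:
    # Compiled out entirely under `python -O`; invalid input then passes silently.
    assert kernel_size % 2 == 1, "kernel_size must be odd"

def validate_with_raise(kernel_size: int) -> None:
    # Executes regardless of the interpreter's optimization level.
    if kernel_size % 2 != 1:
        raise ValueError("kernel_size must be odd")

if __name__ == "__main__":
    try:
        validate_with_raise(4)
    except ValueError as e:
        print(f"raise-based check caught: {e}")  # printed under both `python` and `python -O`
    try:
        validate_with_assert(4)
    except AssertionError as e:
        print(f"assert-based check caught: {e}")  # never printed under `python -O`

Note the convention in the diff itself: user-facing argument validation raises ValueError, while internal invariants (such as the attention-shape check in _EmformerAttention) raise AssertionError.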
...
@@ -37,7 +37,8 @@ class _ConvolutionModule(torch.nn.Module):
         use_group_norm: bool = False,
     ) -> None:
         super().__init__()
-        assert (depthwise_kernel_size - 1) % 2 == 0, "depthwise_kernel_size must be odd to achieve 'SAME' padding."
+        if (depthwise_kernel_size - 1) % 2 != 0:
+            raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.")
         self.layer_norm = torch.nn.LayerNorm(input_dim)
         self.sequential = torch.nn.Sequential(
             torch.nn.Conv1d(
...
@@ -61,7 +61,8 @@ def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers
 def _gen_attention_mask_block(
     col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
 ) -> torch.Tensor:
-    assert len(col_widths) == len(col_mask), "Length of col_widths must match that of col_mask"
+    if len(col_widths) != len(col_mask):
+        raise ValueError("Length of col_widths must match that of col_mask")
     mask_block = [
         torch.ones(num_rows, col_width, device=device)
...
@@ -194,11 +195,12 @@ class _EmformerAttention(torch.nn.Module):
         # Compute attention.
         attention = torch.bmm(attention_probs, reshaped_value)
-        assert attention.shape == (
-            B * self.num_heads,
-            T,
-            self.input_dim // self.num_heads,
-        )
+        if attention.shape != (
+            B * self.num_heads,
+            T,
+            self.input_dim // self.num_heads,
+        ):
+            raise AssertionError("Computed attention has incorrect dimensions")
         attention = attention.transpose(0, 1).contiguous().view(T, B, self.input_dim)

         # Apply output projection.
...
@@ -770,11 +772,12 @@ class _EmformerImpl(torch.nn.Module):
                 output states; list of lists of tensors representing internal state
                 generated in current invocation of ``infer``.
         """
-        assert input.size(1) == self.segment_length + self.right_context_length, (
-            "Per configured segment_length and right_context_length"
-            f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
-            f", but got {input.size(1)}."
-        )
+        if input.size(1) != self.segment_length + self.right_context_length:
+            raise ValueError(
+                "Per configured segment_length and right_context_length"
+                f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
+                f", but got {input.size(1)}."
+            )
         input = input.permute(1, 0, 2)
         right_context_start_idx = input.size(0) - self.right_context_length
         right_context = input[right_context_start_idx:]
...
@@ -280,13 +280,13 @@ class RNNTBeamSearch(torch.nn.Module):
         Returns:
             List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
         """
-        assert input.dim() == 2 or (
-            input.dim() == 3 and input.shape[0] == 1
-        ), "input must be of shape (T, D) or (1, T, D)"
+        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+            raise ValueError("input must be of shape (T, D) or (1, T, D)")
         if input.dim() == 2:
             input = input.unsqueeze(0)

-        assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
+        if length.shape != () and length.shape != (1,):
+            raise ValueError("length must be of shape () or (1,)")
         if input.dim() == 0:
             input = input.unsqueeze(0)
...
@@ -326,13 +326,13 @@ class RNNTBeamSearch(torch.nn.Module):
                 list of lists of tensors representing transcription network
                 internal state generated in current invocation.
         """
-        assert input.dim() == 2 or (
-            input.dim() == 3 and input.shape[0] == 1
-        ), "input must be of shape (T, D) or (1, T, D)"
+        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+            raise ValueError("input must be of shape (T, D) or (1, T, D)")
         if input.dim() == 2:
             input = input.unsqueeze(0)

-        assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
+        if length.shape != () and length.shape != (1,):
+            raise ValueError("length must be of shape () or (1,)")
         if length.dim() == 0:
             length = length.unsqueeze(0)
...
@@ -83,7 +83,8 @@ def _get_conv1d_layer(
         (torch.nn.Conv1d): The corresponding Conv1D layer.
     """
     if padding is None:
-        assert kernel_size % 2 == 1
+        if kernel_size % 2 != 1:
+            raise ValueError("kernel_size must be odd")
         padding = int(dilation * (kernel_size - 1) / 2)

     conv1d = torch.nn.Conv1d(
...
@@ -1033,7 +1034,6 @@ class Tacotron2(nn.Module):
             lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)
-        assert lengths is not None  # For TorchScript compiler

         embedded_inputs = self.embedding(tokens).transpose(1, 2)
         encoder_outputs = self.encoder(embedded_inputs, lengths)
         mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(encoder_outputs, lengths)
...
@@ -512,7 +512,8 @@ def _get_feature_extractor(
         - Large:
             https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61
     """
-    assert norm_mode in ["group_norm", "layer_norm"]
+    if norm_mode not in ["group_norm", "layer_norm"]:
+        raise ValueError("Invalid norm mode")
     blocks = []
     in_channels = 1
     for i, (out_channels, kernel_size, stride) in enumerate(shapes):
...
@@ -155,9 +155,10 @@ class HuBERTPretrainModel(Module):
         self.wav2vec2 = wav2vec2
         self.mask_generator = mask_generator
         self.logit_generator = logit_generator
-        assert (
-            feature_grad_mult is None or 0.0 < feature_grad_mult < 1.0
-        ), f"The value of `feature_grad_mult` must be ``None`` or between (0, 1). Found {feature_grad_mult}"
+        if feature_grad_mult is not None and not 0.0 < feature_grad_mult < 1.0:
+            raise ValueError(
+                f"The value of `feature_grad_mult` must be ``None`` or between (0, 1). Found {feature_grad_mult}"
+            )
         self.feature_grad_mult = feature_grad_mult

     def forward(
...
@@ -204,7 +205,8 @@ class HuBERTPretrainModel(Module):
         x, attention_mask = self.wav2vec2.encoder._preprocess(x, lengths)
         x, mask = self.mask_generator(x, padding_mask)
         x = self.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
-        assert x.shape[1] == labels.shape[1], "The length of label must match that of HuBERT model output"
+        if x.shape[1] != labels.shape[1]:
+            raise ValueError("The length of label must match that of HuBERT model output")
         if padding_mask is not None:
             mask_m = torch.logical_and(~padding_mask, mask)
             mask_u = torch.logical_and(~padding_mask, ~mask_m)
...
@@ -277,8 +277,10 @@ class WaveRNN(nn.Module):
             Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
         """
-        assert waveform.size(1) == 1, "Require the input channel of waveform is 1"
-        assert specgram.size(1) == 1, "Require the input channel of specgram is 1"
+        if waveform.size(1) != 1:
+            raise ValueError("Require the input channel of waveform is 1")
+        if specgram.size(1) != 1:
+            raise ValueError("Require the input channel of specgram is 1")

         # remove channel dimension until the end
         waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)
...