Commit 0f4e1e8c authored by Piyush Soni, committed by Facebook GitHub Bot

Replace assert with raise (#2579)

Summary:
`assert` statements are not executed when Python runs in optimized mode (`python -O` / `PYTHONOPTIMIZE`), so they cannot be relied on for validating user input.

This commit replaces all instances of `assert` in /fbcode/pytorch/audio/torchaudio/functional/functional.py with explicit `raise` statements (raising `ValueError` or `TypeError` as appropriate).
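To illustrate why this matters, here is a minimal, self-contained sketch (the `validate_momentum_*` helpers are hypothetical and not part of torchaudio): under `python -O` the interpreter compiles `assert` statements away entirely, so an assert-based check silently accepts bad input, while an explicit `raise` always fires.

```python
# Hypothetical helpers illustrating the assert -> raise pattern applied in this commit.

def validate_momentum_with_assert(momentum: float) -> None:
    # Compiled away entirely under `python -O` / PYTHONOPTIMIZE=1.
    assert 0 <= momentum < 1, "momentum={} must be in [0, 1)".format(momentum)


def validate_momentum_with_raise(momentum: float) -> None:
    # Always executed, regardless of interpreter optimization flags.
    if not 0 <= momentum < 1:
        raise ValueError("momentum must be in range [0, 1). Found: {}".format(momentum))


if __name__ == "__main__":
    try:
        validate_momentum_with_assert(1.5)
        print("assert-based check was skipped (this happens under -O)")
    except AssertionError as exc:
        print("AssertionError:", exc)

    try:
        validate_momentum_with_raise(1.5)
    except ValueError as exc:
        print("ValueError:", exc)  # raised in both normal and optimized mode
```

Running this file normally and again with `python -O` shows the assert-based check disappearing while the `ValueError` is raised in both cases.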

Pull Request resolved: https://github.com/pytorch/audio/pull/2579

Reviewed By: mthrok

Differential Revision: D38158280

fbshipit-source-id: f8d7fca1c8f9b3955c6ca312b16947eb12894d81
parent 5bf73b59
torchaudio/functional/functional.py
@@ -262,8 +262,9 @@ def griffinlim(
     Returns:
         Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
     """
-    assert momentum < 1, "momentum={} > 1 can be unstable".format(momentum)
-    assert momentum >= 0, "momentum={} < 0".format(momentum)
+    if not 0 <= momentum < 1:
+        raise ValueError("momentum must be in range [0, 1). Found: {}".format(momentum))
     momentum = momentum / (1 + momentum)
 
     # pack batch
@@ -609,14 +610,18 @@ def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> Tensor:
         Tensor: The transformation matrix, to be right-multiplied to
             row-wise data of size (``n_mels``, ``n_mfcc``).
     """
+    if norm is not None and norm != "ortho":
+        raise ValueError("norm must be either 'ortho' or None")
     # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
     n = torch.arange(float(n_mels))
     k = torch.arange(float(n_mfcc)).unsqueeze(1)
     dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
     if norm is None:
         dct *= 2.0
     else:
-        assert norm == "ortho"
         dct[0] *= 1.0 / math.sqrt(2.0)
         dct *= math.sqrt(2.0 / float(n_mels))
     return dct.t()
@@ -878,7 +883,8 @@ def mask_along_axis(
     if axis == 1:
         mask = mask.unsqueeze(-1)
-    assert mask_end - mask_start < mask_param
+    if mask_end - mask_start >= mask_param:
+        raise ValueError("Number of columns to be masked should be less than mask_param")
     specgram = specgram.masked_fill(mask, mask_value)
@@ -922,7 +928,8 @@ def compute_deltas(specgram: Tensor, win_length: int = 5, mode: str = "replicate
     shape = specgram.size()
     specgram = specgram.reshape(1, -1, shape[-1])
-    assert win_length >= 3
+    if win_length < 3:
+        raise ValueError(f"Window length should be greater than or equal to 3. Found win_length {win_length}")
     n = (win_length - 1) // 2
@@ -1414,7 +1421,8 @@ def _get_sinc_resample_kernel(
     orig_freq = int(orig_freq) // gcd
     new_freq = int(new_freq) // gcd
-    assert lowpass_filter_width > 0
+    if lowpass_filter_width <= 0:
+        raise ValueError("Low pass filter width should be positive.")
     base_freq = min(orig_freq, new_freq)
     # This will perform antialiasing filtering by removing the highest frequencies.
     # At first I thought I only needed this when downsampling, but when upsampling
@@ -1540,7 +1548,8 @@ def resample(
         Tensor: The waveform at the new frequency of dimension `(..., time).`
     """
-    assert orig_freq > 0.0 and new_freq > 0.0
+    if orig_freq <= 0.0 or new_freq <= 0.0:
+        raise ValueError("Original frequency and desired frequecy should be positive")
     if orig_freq == new_freq:
         return waveform
@@ -1817,10 +1826,11 @@ def psd(
     psd = torch.einsum("...ct,...et->...tce", [specgram, specgram.conj()])
     if mask is not None:
-        assert (
-            mask.shape[:-1] == specgram.shape[:-2] and mask.shape[-1] == specgram.shape[-1]
-        ), "The dimensions of mask except the channel dimension should be the same as specgram."
-        f"Found {mask.shape} for mask and {specgram.shape} for specgram."
+        if mask.shape[:-1] != specgram.shape[:-2] or mask.shape[-1] != specgram.shape[-1]:
+            raise ValueError(
+                "The dimensions of mask except the channel dimension should be the same as specgram."
+                f"Found {mask.shape} for mask and {specgram.shape} for specgram."
+            )
         # Normalized mask along time dimension:
         if normalize:
             mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
@@ -1844,8 +1854,10 @@ def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> t
     Returns:
         Tensor: The trace of the input Tensor.
     """
-    assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
-    assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
+    if input.ndim < 2:
+        raise ValueError("The dimension of the tensor must be at least 2.")
+    if input.shape[dim1] != input.shape[dim2]:
+        raise ValueError("The size of ``dim1`` and ``dim2`` must be the same.")
     input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
     return input.sum(dim=-1)
@@ -1880,20 +1892,22 @@ def _assert_psd_matrices(psd_s: torch.Tensor, psd_n: torch.Tensor) -> None:
         psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
             Tensor with dimensions `(..., freq, channel, channel)`.
     """
-    assert (
-        psd_s.ndim >= 3 and psd_n.ndim >= 3
-    ), "Expected at least 3D Tensor (..., freq, channel, channel) for psd_s and psd_n."
-    "Found {psd_s.shape} for psd_s and {psd_n.shape} for psd_n."
-    assert (
-        psd_s.is_complex() and psd_n.is_complex()
-    ), "The type of psd_s and psd_n must be ``torch.cfloat`` or ``torch.cdouble``."
-    f"Found {psd_s.dtype} for psd_s and {psd_n.dtype} for psd_n."
-    assert (
-        psd_s.shape == psd_n.shape
-    ), f"The dimensions of psd_s and psd_n should be the same. Found {psd_s.shape} and {psd_n.shape}."
-    assert (
-        psd_s.shape[-1] == psd_s.shape[-2]
-    ), f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}."
+    if psd_s.ndim < 3 or psd_n.ndim < 3:
+        raise ValueError(
+            "Expected at least 3D Tensor (..., freq, channel, channel) for psd_s and psd_n. "
+            f"Found {psd_s.shape} for psd_s and {psd_n.shape} for psd_n."
+        )
+    if not (psd_s.is_complex() and psd_n.is_complex()):
+        raise TypeError(
+            "The type of psd_s and psd_n must be ``torch.cfloat`` or ``torch.cdouble``. "
+            f"Found {psd_s.dtype} for psd_s and {psd_n.dtype} for psd_n."
+        )
+    if psd_s.shape != psd_n.shape:
+        raise ValueError(
+            f"The dimensions of psd_s and psd_n should be the same. Found {psd_s.shape} and {psd_n.shape}."
+        )
+    if psd_s.shape[-1] != psd_s.shape[-2]:
+        raise ValueError(f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}.")
 
 
 def mvdr_weights_souden(
@@ -2005,19 +2019,22 @@ def mvdr_weights_rtf(
     Returns:
         torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
     """
-    assert rtf.ndim >= 2, f"Expected at least 2D Tensor (..., freq, channel) for rtf. Found {rtf.shape}."
-    assert psd_n.ndim >= 3, f"Expected at least 3D Tensor (..., freq, channel, channel) for psd_n. Found {psd_n.shape}."
-    assert (
-        rtf.is_complex() and psd_n.is_complex()
-    ), "The type of rtf and psd_n must be ``torch.cfloat`` or ``torch.cdouble``."
-    f"Found {rtf.dtype} for rtf and {psd_n.dtype} for psd_n."
-    assert (
-        rtf.shape == psd_n.shape[:-1]
-    ), "The dimensions of rtf and the dimensions withou the last dimension of psd_n should be the same."
-    f"Found {rtf.shape} for rtf and {psd_n.shape} for psd_n."
-    assert (
-        psd_n.shape[-1] == psd_n.shape[-2]
-    ), f"The last two dimensions of psd_n should be the same. Found {psd_n.shape}."
+    if rtf.ndim < 2:
+        raise ValueError(f"Expected at least 2D Tensor (..., freq, channel) for rtf. Found {rtf.shape}.")
+    if psd_n.ndim < 3:
+        raise ValueError(f"Expected at least 3D Tensor (..., freq, channel, channel) for psd_n. Found {psd_n.shape}.")
+    if not (rtf.is_complex() and psd_n.is_complex()):
+        raise TypeError(
+            "The type of rtf and psd_n must be ``torch.cfloat`` or ``torch.cdouble``. "
+            f"Found {rtf.dtype} for rtf and {psd_n.dtype} for psd_n."
+        )
+    if rtf.shape != psd_n.shape[:-1]:
+        raise ValueError(
+            "The dimensions of rtf and the dimensions withou the last dimension of psd_n should be the same. "
+            f"Found {rtf.shape} for rtf and {psd_n.shape} for psd_n."
+        )
+    if psd_n.shape[-1] != psd_n.shape[-2]:
+        raise ValueError(f"The last two dimensions of psd_n should be the same. Found {psd_n.shape}.")
     if diagonal_loading:
         psd_n = _tik_reg(psd_n, reg=diag_eps)
@@ -2056,10 +2073,10 @@ def rtf_evd(psd_s: Tensor) -> Tensor:
         Tensor: The estimated complex-valued RTF of target speech.
             Tensor of dimension `(..., freq, channel)`
     """
-    assert psd_s.is_complex(), f"The type of psd_s must be ``torch.cfloat`` or ``torch.cdouble``. Found {psd_s.dtype}."
-    assert (
-        psd_s.shape[-1] == psd_s.shape[-2]
-    ), f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}."
+    if not psd_s.is_complex():
+        raise TypeError(f"The type of psd_s must be ``torch.cfloat`` or ``torch.cdouble``. Found {psd_s.dtype}.")
+    if psd_s.shape[-1] != psd_s.shape[-2]:
+        raise ValueError(f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}.")
     _, v = torch.linalg.eigh(psd_s)  # v is sorted along with eigenvalues in ascending order
     rtf = v[..., -1]  # choose the eigenvector with max eigenvalue
     return rtf
@@ -2098,7 +2115,8 @@ def rtf_power(
             Tensor of dimension `(..., freq, channel)`.
     """
     _assert_psd_matrices(psd_s, psd_n)
-    assert n_iter > 0, "The number of iteration must be greater than 0."
+    if n_iter <= 0:
+        raise ValueError("The number of iteration must be greater than 0.")
 
     # Apply diagonal loading to psd_n to improve robustness.
     if diagonal_loading:
@@ -2150,15 +2168,18 @@ def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
         Tensor: The single-channel complex-valued enhanced spectrum.
             Tensor of dimension `(..., freq, time)`
     """
-    assert (
-        beamform_weights.shape[:-2] == specgram.shape[:-3]
-    ), "The dimensions except the last two dimensions of beamform_weights should be the same "
-    "as the dimensions except the last three dimensions of specgram."
-    f"Found {beamform_weights.shape} for beamform_weights and {specgram.shape} for specgram."
-    assert (
-        beamform_weights.is_complex() and specgram.is_complex()
-    ), "The type of beamform_weights and specgram must be ``torch.cfloat`` or ``torch.cdouble``."
-    f"Found {beamform_weights.dtype} for beamform_weights and {specgram.dtype} for specgram."
+    if beamform_weights.shape[:-2] != specgram.shape[:-3]:
+        raise ValueError(
+            "The dimensions except the last two dimensions of beamform_weights should be the same "
+            "as the dimensions except the last three dimensions of specgram. "
+            f"Found {beamform_weights.shape} for beamform_weights and {specgram.shape} for specgram."
+        )
+
+    if not (beamform_weights.is_complex() and specgram.is_complex()):
+        raise TypeError(
+            "The type of beamform_weights and specgram must be ``torch.cfloat`` or ``torch.cdouble``. "
+            f"Found {beamform_weights.dtype} for beamform_weights and {specgram.dtype} for specgram."
+        )
     # (..., freq, channel) x (..., channel, freq, time) -> (..., freq, time)
     specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_weights.conj(), specgram])