Commit 0f4e1e8c authored by Piyush Soni's avatar Piyush Soni Committed by Facebook GitHub Bot
Browse files

Replace assert with raise (#2579)

Summary:
`assert` is not executed when running in optimized mode.

This commit replaces all instances of "assert" in /fbcode/pytorch/audio/torchaudio/functional/functional.py

Pull Request resolved: https://github.com/pytorch/audio/pull/2579

Reviewed By: mthrok

Differential Revision: D38158280

fbshipit-source-id: f8d7fca1c8f9b3955c6ca312b16947eb12894d81
parent 5bf73b59
......@@ -262,8 +262,9 @@ def griffinlim(
Returns:
Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
"""
assert momentum < 1, "momentum={} > 1 can be unstable".format(momentum)
assert momentum >= 0, "momentum={} < 0".format(momentum)
if not 0 <= momentum < 1:
raise ValueError("momentum must be in range [0, 1). Found: {}".format(momentum))
momentum = momentum / (1 + momentum)
# pack batch
......@@ -609,14 +610,18 @@ def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> Tensor:
Tensor: The transformation matrix, to be right-multiplied to
row-wise data of size (``n_mels``, ``n_mfcc``).
"""
if norm is not None and norm != "ortho":
raise ValueError("norm must be either 'ortho' or None")
# http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
n = torch.arange(float(n_mels))
k = torch.arange(float(n_mfcc)).unsqueeze(1)
dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k) # size (n_mfcc, n_mels)
if norm is None:
dct *= 2.0
else:
assert norm == "ortho"
dct[0] *= 1.0 / math.sqrt(2.0)
dct *= math.sqrt(2.0 / float(n_mels))
return dct.t()
......@@ -878,7 +883,8 @@ def mask_along_axis(
if axis == 1:
mask = mask.unsqueeze(-1)
assert mask_end - mask_start < mask_param
if mask_end - mask_start >= mask_param:
raise ValueError("Number of columns to be masked should be less than mask_param")
specgram = specgram.masked_fill(mask, mask_value)
......@@ -922,7 +928,8 @@ def compute_deltas(specgram: Tensor, win_length: int = 5, mode: str = "replicate
shape = specgram.size()
specgram = specgram.reshape(1, -1, shape[-1])
assert win_length >= 3
if win_length < 3:
raise ValueError(f"Window length should be greater than or equal to 3. Found win_length {win_length}")
n = (win_length - 1) // 2
......@@ -1414,7 +1421,8 @@ def _get_sinc_resample_kernel(
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd
assert lowpass_filter_width > 0
if lowpass_filter_width <= 0:
raise ValueError("Low pass filter width should be positive.")
base_freq = min(orig_freq, new_freq)
# This will perform antialiasing filtering by removing the highest frequencies.
# At first I thought I only needed this when downsampling, but when upsampling
......@@ -1540,7 +1548,8 @@ def resample(
Tensor: The waveform at the new frequency of dimension `(..., time).`
"""
assert orig_freq > 0.0 and new_freq > 0.0
if orig_freq <= 0.0 or new_freq <= 0.0:
raise ValueError("Original frequency and desired frequecy should be positive")
if orig_freq == new_freq:
return waveform
......@@ -1817,10 +1826,11 @@ def psd(
psd = torch.einsum("...ct,...et->...tce", [specgram, specgram.conj()])
if mask is not None:
assert (
mask.shape[:-1] == specgram.shape[:-2] and mask.shape[-1] == specgram.shape[-1]
), "The dimensions of mask except the channel dimension should be the same as specgram."
f"Found {mask.shape} for mask and {specgram.shape} for specgram."
if mask.shape[:-1] != specgram.shape[:-2] or mask.shape[-1] != specgram.shape[-1]:
raise ValueError(
"The dimensions of mask except the channel dimension should be the same as specgram."
f"Found {mask.shape} for mask and {specgram.shape} for specgram."
)
# Normalized mask along time dimension:
if normalize:
mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
......@@ -1844,8 +1854,10 @@ def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> t
Returns:
Tensor: The trace of the input Tensor.
"""
assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
if input.ndim < 2:
raise ValueError("The dimension of the tensor must be at least 2.")
if input.shape[dim1] != input.shape[dim2]:
raise ValueError("The size of ``dim1`` and ``dim2`` must be the same.")
input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
return input.sum(dim=-1)
......@@ -1880,20 +1892,22 @@ def _assert_psd_matrices(psd_s: torch.Tensor, psd_n: torch.Tensor) -> None:
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
"""
assert (
psd_s.ndim >= 3 and psd_n.ndim >= 3
), "Expected at least 3D Tensor (..., freq, channel, channel) for psd_s and psd_n."
"Found {psd_s.shape} for psd_s and {psd_n.shape} for psd_n."
assert (
psd_s.is_complex() and psd_n.is_complex()
), "The type of psd_s and psd_n must be ``torch.cfloat`` or ``torch.cdouble``."
f"Found {psd_s.dtype} for psd_s and {psd_n.dtype} for psd_n."
assert (
psd_s.shape == psd_n.shape
), f"The dimensions of psd_s and psd_n should be the same. Found {psd_s.shape} and {psd_n.shape}."
assert (
psd_s.shape[-1] == psd_s.shape[-2]
), f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}."
if psd_s.ndim < 3 or psd_n.ndim < 3:
raise ValueError(
"Expected at least 3D Tensor (..., freq, channel, channel) for psd_s and psd_n. "
f"Found {psd_s.shape} for psd_s and {psd_n.shape} for psd_n."
)
if not (psd_s.is_complex() and psd_n.is_complex()):
raise TypeError(
"The type of psd_s and psd_n must be ``torch.cfloat`` or ``torch.cdouble``. "
f"Found {psd_s.dtype} for psd_s and {psd_n.dtype} for psd_n."
)
if psd_s.shape != psd_n.shape:
raise ValueError(
f"The dimensions of psd_s and psd_n should be the same. Found {psd_s.shape} and {psd_n.shape}."
)
if psd_s.shape[-1] != psd_s.shape[-2]:
raise ValueError(f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}.")
def mvdr_weights_souden(
......@@ -2005,19 +2019,22 @@ def mvdr_weights_rtf(
Returns:
torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
"""
assert rtf.ndim >= 2, f"Expected at least 2D Tensor (..., freq, channel) for rtf. Found {rtf.shape}."
assert psd_n.ndim >= 3, f"Expected at least 3D Tensor (..., freq, channel, channel) for psd_n. Found {psd_n.shape}."
assert (
rtf.is_complex() and psd_n.is_complex()
), "The type of rtf and psd_n must be ``torch.cfloat`` or ``torch.cdouble``."
f"Found {rtf.dtype} for rtf and {psd_n.dtype} for psd_n."
assert (
rtf.shape == psd_n.shape[:-1]
), "The dimensions of rtf and the dimensions withou the last dimension of psd_n should be the same."
f"Found {rtf.shape} for rtf and {psd_n.shape} for psd_n."
assert (
psd_n.shape[-1] == psd_n.shape[-2]
), f"The last two dimensions of psd_n should be the same. Found {psd_n.shape}."
if rtf.ndim < 2:
raise ValueError(f"Expected at least 2D Tensor (..., freq, channel) for rtf. Found {rtf.shape}.")
if psd_n.ndim < 3:
raise ValueError(f"Expected at least 3D Tensor (..., freq, channel, channel) for psd_n. Found {psd_n.shape}.")
if not (rtf.is_complex() and psd_n.is_complex()):
raise TypeError(
"The type of rtf and psd_n must be ``torch.cfloat`` or ``torch.cdouble``. "
f"Found {rtf.dtype} for rtf and {psd_n.dtype} for psd_n."
)
if rtf.shape != psd_n.shape[:-1]:
raise ValueError(
"The dimensions of rtf and the dimensions withou the last dimension of psd_n should be the same. "
f"Found {rtf.shape} for rtf and {psd_n.shape} for psd_n."
)
if psd_n.shape[-1] != psd_n.shape[-2]:
raise ValueError(f"The last two dimensions of psd_n should be the same. Found {psd_n.shape}.")
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps)
......@@ -2056,10 +2073,10 @@ def rtf_evd(psd_s: Tensor) -> Tensor:
Tensor: The estimated complex-valued RTF of target speech.
Tensor of dimension `(..., freq, channel)`
"""
assert psd_s.is_complex(), f"The type of psd_s must be ``torch.cfloat`` or ``torch.cdouble``. Found {psd_s.dtype}."
assert (
psd_s.shape[-1] == psd_s.shape[-2]
), f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}."
if not psd_s.is_complex():
raise TypeError(f"The type of psd_s must be ``torch.cfloat`` or ``torch.cdouble``. Found {psd_s.dtype}.")
if psd_s.shape[-1] != psd_s.shape[-2]:
raise ValueError(f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}.")
_, v = torch.linalg.eigh(psd_s) # v is sorted along with eigenvalues in ascending order
rtf = v[..., -1] # choose the eigenvector with max eigenvalue
return rtf
......@@ -2098,7 +2115,8 @@ def rtf_power(
Tensor of dimension `(..., freq, channel)`.
"""
_assert_psd_matrices(psd_s, psd_n)
assert n_iter > 0, "The number of iteration must be greater than 0."
if n_iter <= 0:
raise ValueError("The number of iteration must be greater than 0.")
# Apply diagonal loading to psd_n to improve robustness.
if diagonal_loading:
......@@ -2150,15 +2168,18 @@ def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
Tensor: The single-channel complex-valued enhanced spectrum.
Tensor of dimension `(..., freq, time)`
"""
assert (
beamform_weights.shape[:-2] == specgram.shape[:-3]
), "The dimensions except the last two dimensions of beamform_weights should be the same "
"as the dimensions except the last three dimensions of specgram."
f"Found {beamform_weights.shape} for beamform_weights and {specgram.shape} for specgram."
assert (
beamform_weights.is_complex() and specgram.is_complex()
), "The type of beamform_weights and specgram must be ``torch.cfloat`` or ``torch.cdouble``."
f"Found {beamform_weights.dtype} for beamform_weights and {specgram.dtype} for specgram."
if beamform_weights.shape[:-2] != specgram.shape[:-3]:
raise ValueError(
"The dimensions except the last two dimensions of beamform_weights should be the same "
"as the dimensions except the last three dimensions of specgram. "
f"Found {beamform_weights.shape} for beamform_weights and {specgram.shape} for specgram."
)
if not (beamform_weights.is_complex() and specgram.is_complex()):
raise TypeError(
"The type of beamform_weights and specgram must be ``torch.cfloat`` or ``torch.cdouble``. "
f"Found {beamform_weights.dtype} for beamform_weights and {specgram.dtype} for specgram."
)
# (..., freq, channel) x (..., channel, freq, time) -> (..., freq, time)
specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_weights.conj(), specgram])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment