Commit fb4eb981 authored by Kunal Upadya's avatar Kunal Upadya Committed by Facebook GitHub Bot
Browse files

Fixed argument validation in TorchAudio filtering (#2609)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2609

Converted argument validations in torchaudio/functional/filtering from assert based validation to the preferred if-then raise validation. Added specific error messages in all cases.

Reviewed By: mthrok

Differential Revision: D38515029

fbshipit-source-id: 6c644a042f86c6feb2bbe8bd02fdb484fe27fae9
parent 733ca909
......@@ -932,7 +932,8 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
try:
_lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
except RuntimeError as err:
assert str(err) == "No such operator torchaudio::_lfilter_core_loop"
if str(err) != "No such operator torchaudio::_lfilter_core_loop":
raise
_lfilter_core_cpu_loop = _lfilter_core_generic_loop
......@@ -942,14 +943,24 @@ def _lfilter_core(
b_coeffs: Tensor,
) -> Tensor:
assert a_coeffs.size() == b_coeffs.size()
assert len(waveform.size()) == 3
assert waveform.device == a_coeffs.device
assert b_coeffs.device == a_coeffs.device
if a_coeffs.size() != b_coeffs.size():
raise ValueError(
"Expected coeffs to be the same size."
f"Found a_coeffs size: {a_coeffs.size()}, b_coeffs size: {b_coeffs.size()}"
)
if waveform.ndim != 3:
raise ValueError(f"Expected waveform to be 3 dimensional. Found: {waveform.ndim}")
if not (waveform.device == a_coeffs.device == b_coeffs.device):
raise ValueError(
"Expected waveform and coeffs to be on the same device."
f"Found: waveform device:{waveform.device}, a_coeffs device: {a_coeffs.device}, "
f"b_coeffs device: {b_coeffs.device}"
)
n_batch, n_channel, n_sample = waveform.size()
n_order = a_coeffs.size(1)
assert n_order > 0
if n_order <= 0:
raise ValueError(f"Expected n_order to be positive. Found: {n_order}")
# Pad the input and create output
......@@ -983,7 +994,8 @@ def _lfilter_core(
try:
_lfilter = torch.ops.torchaudio._lfilter
except RuntimeError as err:
assert str(err) == "No such operator torchaudio::_lfilter"
if str(err) != "No such operator torchaudio::_lfilter":
raise
_lfilter = _lfilter_core
......@@ -1018,13 +1030,23 @@ def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool =
Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
are 2D Tensors, or `(..., time)` otherwise.
"""
assert a_coeffs.size() == b_coeffs.size()
assert a_coeffs.ndim <= 2
if a_coeffs.size() != b_coeffs.size():
raise ValueError(
"Expected coeffs to be the same size."
f"Found: a_coeffs size: {a_coeffs.size()}, b_coeffs size: {b_coeffs.size()}"
)
if a_coeffs.ndim > 2:
raise ValueError(f"Expected coeffs to have greater than 1 dimension. Found: {a_coeffs.ndim}")
if a_coeffs.ndim > 1:
if batching:
assert waveform.ndim > 1
assert waveform.shape[-2] == a_coeffs.shape[0]
if waveform.ndim <= 0:
raise ValueError("Expected waveform to have a positive number of dimensions." f"Found: {waveform.ndim}")
if waveform.shape[-2] != a_coeffs.shape[0]:
raise ValueError(
"Expected number of batches in waveform and coeffs to be the same."
f"Found: coeffs batches: {a_coeffs.shape[0]}, waveform batches: {waveform.shape[-2]}"
)
else:
waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
else:
......@@ -1090,7 +1112,8 @@ def _overdrive_core_loop_generic(
try:
_overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop
except RuntimeError as err:
assert str(err) == "No such operator torchaudio::_overdrive_core_loop"
if str(err) != "No such operator torchaudio::_overdrive_core_loop":
raise
_overdrive_core_loop_cpu = _overdrive_core_loop_generic
......@@ -1377,7 +1400,11 @@ def _measure(
boot_count: int,
) -> float:
assert spectrum.size()[-1] == noise_spectrum.size()[-1]
if spectrum.size(-1) != noise_spectrum.size(-1):
raise ValueError(
"Expected spectrum size to match noise spectrum size in final dimension."
f"Found: spectrum size: {spectrum.size()}, noise_spectrum size: {noise_spectrum.size()}"
)
samplesLen_ns = samples.size()[-1]
dft_len_ws = spectrum.size()[-1]
......@@ -1495,7 +1522,7 @@ def vad(
for when the noise level is decreasing. (Default: 0.01)
noise_reduction_amount (float, optional) Amount of noise reduction to use in
the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
measure_freq (float, optional) Frequency of the algorithms
measure_freq (float, optional) Frequency of the algorithm's
processing/measurements. (Default: 20.0)
measure_duration: (float, optional) Measurement duration.
(Default: Twice the measurement period; i.e. with overlap.)
......@@ -1565,7 +1592,11 @@ def vad(
cepstrum_end = math.floor(sample_rate * 0.5 / hp_lifter_freq)
cepstrum_end = min(cepstrum_end, dft_len_ws // 4)
assert cepstrum_end > cepstrum_start
if cepstrum_end <= cepstrum_start:
raise ValueError(
"Expected cepstrum_start to be smaller than cepstrum_end."
f"Found: cepstrum_start: {cepstrum_start}, cepstrum_end: {cepstrum_end}."
)
noise_up_time_mult = math.exp(-1.0 / (noise_up_time * measure_freq))
noise_down_time_mult = math.exp(-1.0 / (noise_down_time * measure_freq))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment