Commit 5428e283 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Fix filtering function fallback mechanism (#2953)

Summary:
lfilter, overdrive have faster implementation written in C++. If they are not available, torchaudio is supposed to fall back on Python-based implementation.

The original fallback mechanism relied on error type and messages from PyTorch core, which has been changed.

This commit updates it for more proper fallback mechanism.

Pull Request resolved: https://github.com/pytorch/audio/pull/2953

Reviewed By: hwangjeff

Differential Revision: D42344893

Pulled By: mthrok

fbshipit-source-id: 18ce5c1aa1c69d0d2ab469b0b0c36c0221f5ccfd
parent f70b970a
...@@ -5,6 +5,8 @@ from typing import Optional ...@@ -5,6 +5,8 @@ from typing import Optional
import torch import torch
from torch import Tensor from torch import Tensor
from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE
def _dB2Linear(x: float) -> float: def _dB2Linear(x: float) -> float:
return math.exp(x * math.log(10) / 20.0) return math.exp(x * math.log(10) / 20.0)
...@@ -929,11 +931,9 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T ...@@ -929,11 +931,9 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
padded_output_waveform[:, :, i_sample + n_order - 1] = o0 padded_output_waveform[:, :, i_sample + n_order - 1] = o0
try: if _IS_TORCHAUDIO_EXT_AVAILABLE:
_lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
except RuntimeError as err: else:
if str(err) != "No such operator torchaudio::_lfilter_core_loop":
raise
_lfilter_core_cpu_loop = _lfilter_core_generic_loop _lfilter_core_cpu_loop = _lfilter_core_generic_loop
...@@ -991,11 +991,9 @@ def _lfilter_core( ...@@ -991,11 +991,9 @@ def _lfilter_core(
return output return output
try: if _IS_TORCHAUDIO_EXT_AVAILABLE:
_lfilter = torch.ops.torchaudio._lfilter _lfilter = torch.ops.torchaudio._lfilter
except RuntimeError as err: else:
if str(err) != "No such operator torchaudio::_lfilter":
raise
_lfilter = _lfilter_core _lfilter = _lfilter_core
...@@ -1109,11 +1107,9 @@ def _overdrive_core_loop_generic( ...@@ -1109,11 +1107,9 @@ def _overdrive_core_loop_generic(
output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75 output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75
try: if _IS_TORCHAUDIO_EXT_AVAILABLE:
_overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop _overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop
except RuntimeError as err: else:
if str(err) != "No such operator torchaudio::_overdrive_core_loop":
raise
_overdrive_core_loop_cpu = _overdrive_core_loop_generic _overdrive_core_loop_cpu = _overdrive_core_loop_generic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment