Commit 1923be04 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Use deterministic algorithms for filtfilt autograd tests (#3150)

Summary:
The `filtfilt` function uses `lfilter`, which internally calls a `conv1d` operation. `conv1d` has a nondeterministic backward pass on CUDA by default, so its autograd tests are expected to fail (see https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html). This PR enables deterministic algorithms in the autograd tests so that the `filtfilt`-related tests pass.
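For illustration, the fix follows this pattern (a minimal sketch, not the test suite's exact code; the input shapes and scaling below are arbitrary assumptions):

import torch
import torchaudio.functional as F
from torch.autograd import gradcheck

# Illustrative double-precision inputs; gradcheck requires float64.
x = 0.1 * torch.randn(2, 64, dtype=torch.float64, requires_grad=True)
a = torch.tensor([0.7, 0.2, 0.6], dtype=torch.float64)
b = torch.tensor([0.4, 0.2, 0.9], dtype=torch.float64)

# Force deterministic kernels while checking gradients, then restore
# the previous global setting.
prev = torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(True)
try:
    gradcheck(F.filtfilt, (x, a, b))
finally:
    torch.use_deterministic_algorithms(prev)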

Pull Request resolved: https://github.com/pytorch/audio/pull/3150

Reviewed By: mthrok

Differential Revision: D43872977

Pulled By: nateanl

fbshipit-source-id: c3d6ec281f34db8a7092526ccb245797bf2338da
parent 67a49f3c
+from .autograd_utils import use_deterministic_algorithms
from .backend_utils import set_audio_backend
from .case_utils import (
    HttpServerMixin,
@@ -66,5 +67,6 @@ __all__ = [
    "get_image",
    "rgb_to_gray",
    "rgb_to_yuv_ccir",
+    "use_deterministic_algorithms",
    "zip_equal",
]
import contextlib

import torch


@contextlib.contextmanager
def use_deterministic_algorithms(mode: bool, warn_only: bool):
    r"""Context manager that temporarily enables or disables deterministic algorithms.

    Upon exiting the context manager, the previous state of the flags is restored.
    """
    # Save the current global settings so they can be restored afterwards.
    previous_mode: bool = torch.are_deterministic_algorithms_enabled()
    previous_warn_only: bool = torch.is_deterministic_algorithms_warn_only_enabled()
    try:
        torch.use_deterministic_algorithms(mode, warn_only=warn_only)
        yield
    finally:
        # Restore the flags even if the body raised.
        torch.use_deterministic_algorithms(previous_mode, warn_only=previous_warn_only)
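For illustration, a quick sketch of how this helper behaves (assuming determinism is disabled globally beforehand, which is PyTorch's default):

import torch

print(torch.are_deterministic_algorithms_enabled())  # False by default

with use_deterministic_algorithms(True, False):
    # Inside the block, nondeterministic ops raise a RuntimeError
    # (warn_only=False) instead of silently running.
    print(torch.are_deterministic_algorithms_enabled())  # True

# On exit, the previous global flags are restored.
print(torch.are_deterministic_algorithms_enabled())  # False again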
@@ -6,7 +6,14 @@ import torchaudio.functional as F
from parameterized import parameterized
from torch import Tensor
from torch.autograd import gradcheck, gradgradcheck
-from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, nested_params, rnnt_utils, TestBaseMixin
+from torchaudio_unittest.common_utils import (
+    get_spectrogram,
+    get_whitenoise,
+    nested_params,
+    rnnt_utils,
+    TestBaseMixin,
+    use_deterministic_algorithms,
+)
class Autograd(TestBaseMixin):
@@ -71,26 +78,30 @@ class Autograd(TestBaseMixin):
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        a.requires_grad = True
-        self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)
+        with use_deterministic_algorithms(True, False):
+            self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)

    def test_filtfilt_b(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        b.requires_grad = True
-        self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)
+        with use_deterministic_algorithms(True, False):
+            self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)

    def test_filtfilt_all_inputs(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
-        self.assert_grad(F.filtfilt, (x, a, b))
+        with use_deterministic_algorithms(True, False):
+            self.assert_grad(F.filtfilt, (x, a, b))

    def test_filtfilt_batching(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
-        self.assert_grad(F.filtfilt, (x, a, b))
+        with use_deterministic_algorithms(True, False):
+            self.assert_grad(F.filtfilt, (x, a, b))

    def test_biquad(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
......