Commit 2f5fcf4f authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Remove function input parameters from data aug functional tests (#3011)

Summary:
Passing functions as test parameters causes issues on some platforms. This PR updates the functional tests to pass functions by name instead.

Pull Request resolved: https://github.com/pytorch/audio/pull/3011

Reviewed By: mthrok

Differential Revision: D42748106

Pulled By: hwangjeff

fbshipit-source-id: 4d81dabe4aff2293bc344a457a034a2d9af024e2
parent aa760caf
...@@ -336,7 +336,7 @@ class Autograd(TestBaseMixin): ...@@ -336,7 +336,7 @@ class Autograd(TestBaseMixin):
self.assert_grad(F.apply_beamforming, (beamform_weights, specgram)) self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))
@nested_params( @nested_params(
[F.convolve, F.fftconvolve], ["convolve", "fftconvolve"],
["full", "valid", "same"], ["full", "valid", "same"],
) )
def test_convolve(self, fn, mode): def test_convolve(self, fn, mode):
...@@ -344,7 +344,7 @@ class Autograd(TestBaseMixin): ...@@ -344,7 +344,7 @@ class Autograd(TestBaseMixin):
L_x, L_y = 23, 40 L_x, L_y = 23, 40
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device) x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device) y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
self.assert_grad(fn, (x, y, mode)) self.assert_grad(getattr(F, fn), (x, y, mode))
def test_add_noise(self): def test_add_noise(self):
leading_dims = (5, 2, 3) leading_dims = (5, 2, 3)
......
...@@ -409,7 +409,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -409,7 +409,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assert_batch_consistency(F.apply_beamforming, (beamform_weights, specgram)) self.assert_batch_consistency(F.apply_beamforming, (beamform_weights, specgram))
@common_utils.nested_params( @common_utils.nested_params(
[F.convolve, F.fftconvolve], ["convolve", "fftconvolve"],
["full", "valid", "same"], ["full", "valid", "same"],
) )
def test_convolve(self, fn, mode): def test_convolve(self, fn, mode):
...@@ -418,6 +418,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -418,6 +418,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device) x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device) y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
fn = getattr(F, fn)
actual = fn(x, y, mode) actual = fn(x, y, mode)
expected = torch.stack( expected = torch.stack(
[ [
......
import math
import torch import torch
import torchaudio.prototype.functional as F import torchaudio.prototype.functional as F
from parameterized import param, parameterized from parameterized import param, parameterized
from torchaudio.functional import lfilter
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
from .dsp_utils import freq_ir as freq_ir_np, oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np from .dsp_utils import freq_ir as freq_ir_np, oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment