Commit 8ba323bb authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add torchscript test to oscillator_bank (#2864)

Summary:
Missing from https://github.com/pytorch/audio/issues/2848

Pull Request resolved: https://github.com/pytorch/audio/pull/2864

Reviewed By: carolineechen

Differential Revision: D41413381

Pulled By: mthrok

fbshipit-source-id: 4377ed4a59504c6ade9ee6f42938a2bc3f04fb73
parent 92b6847e
...@@ -58,3 +58,10 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin): ...@@ -58,3 +58,10 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
n_barks = 10 n_barks = 10
sample_rate = 16000 sample_rate = 16000
self._assert_consistency(F.barkscale_fbanks, (n_stft, f_min, f_max, n_barks, sample_rate, "traunmuller")) self._assert_consistency(F.barkscale_fbanks, (n_stft, f_min, f_max, n_barks, sample_rate, "traunmuller"))
def test_oscillator_bank(self):
num_frames, num_pitches, sample_rate = 8000, 8, 8000
freq = torch.rand((num_frames, num_pitches), dtype=self.dtype, device=self.device)
amps = torch.ones_like(freq)
self._assert_consistency(F.oscillator_bank, (freq, amps, sample_rate, "sum"))
...@@ -71,7 +71,7 @@ def oscillator_bank( ...@@ -71,7 +71,7 @@ def oscillator_bank(
# We might add angular_cumsum if it turned out to be undesirable. # We might add angular_cumsum if it turned out to be undesirable.
pi2 = 2.0 * torch.pi pi2 = 2.0 * torch.pi
freqs = frequencies * pi2 / sample_rate % pi2 freqs = frequencies * pi2 / sample_rate % pi2
phases = torch.cumsum(freqs, axis=-2) phases = torch.cumsum(freqs, dim=-2)
waveform = amplitudes * torch.sin(phases) waveform = amplitudes * torch.sin(phases)
if reduction == "sum": if reduction == "sum":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment