Commit f2b2f05a authored by Jeff Hwang's avatar Jeff Hwang Committed by Facebook GitHub Bot
Browse files

Revise VGGish pipeline test again (#3551)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3551

Restores VGGish pipeline test to be a function rather than class.

Reviewed By: mthrok

Differential Revision: D48236197

fbshipit-source-id: 25ac19d87a7a0964a9c3f7552037cd6c21dc38a9
parent 9bd7ca51
import unittest
import torchaudio
from torchaudio.prototype.pipelines import VGGISH
class VGGishPipelineTest(unittest.TestCase):
def test_vggish(self):
input_sr = VGGISH.sample_rate
input_proc = VGGISH.get_input_processor()
model = VGGISH.get_model()
path = torchaudio.utils.download_asset("test-assets/Chopin_Ballade_-1_In_G_Minor,_Op._23_excerpt.mp3")
waveform, sr = torchaudio.load(path, backend="ffmpeg")
waveform = waveform.mean(axis=0)
waveform = torchaudio.functional.resample(waveform, sr, input_sr)
batch = input_proc(waveform)
assert batch.shape == (62, 1, 96, 64)
output = model(batch)
assert output.shape == (62, 128)
def test_vggish():
input_sr = VGGISH.sample_rate
input_proc = VGGISH.get_input_processor()
model = VGGISH.get_model()
path = torchaudio.utils.download_asset("test-assets/Chopin_Ballade_-1_In_G_Minor,_Op._23_excerpt.mp3")
waveform, sr = torchaudio.load(path, backend="ffmpeg")
waveform = waveform.mean(axis=0)
waveform = torchaudio.functional.resample(waveform, sr, input_sr)
batch = input_proc(waveform)
assert batch.shape == (62, 1, 96, 64)
output = model(batch)
assert output.shape == (62, 128)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment