Unverified Commit ab733e7b authored by jimchen90's avatar jimchen90 Committed by GitHub
Browse files

update wav2letter test (#722)


Co-authored-by: default avatarJi Chen <jimchen90@devfair0160.h2.fair>
parent e9f19c35
import pytest
import torch
from torchaudio.models import Wav2Letter
class TestWav2Letter:
@pytest.mark.parametrize('batch_size', [2])
@pytest.mark.parametrize('num_features', [1])
@pytest.mark.parametrize('num_classes', [40])
@pytest.mark.parametrize('input_length', [320])
def test_waveform(self, batch_size, num_features, num_classes, input_length):
model = Wav2Letter()
def test_waveform(self):
batch_size = 2
num_features = 1
num_classes = 40
input_length = 320
model = Wav2Letter(num_classes=num_classes, num_features=num_features)
x = torch.rand(batch_size, num_features, input_length)
out = model(x)
assert out.size() == (batch_size, num_classes, 2)
@pytest.mark.parametrize('batch_size', [2])
@pytest.mark.parametrize('num_features', [13])
@pytest.mark.parametrize('num_classes', [40])
@pytest.mark.parametrize('input_length', [2])
def test_mfcc(self, batch_size, num_features, num_classes, input_length):
model = Wav2Letter(input_type="mfcc", num_features=13)
def test_mfcc(self):
batch_size = 2
num_features = 13
num_classes = 40
input_length = 2
model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)
x = torch.rand(batch_size, num_features, input_length)
out = model(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment