Unverified Commit ab733e7b authored by jimchen90's avatar jimchen90 Committed by GitHub
Browse files

update wav2letter test (#722)


Co-authored-by: default avatarJi Chen <jimchen90@devfair0160.h2.fair>
parent e9f19c35
import pytest
import torch import torch
from torchaudio.models import Wav2Letter from torchaudio.models import Wav2Letter
class TestWav2Letter: class TestWav2Letter:
@pytest.mark.parametrize('batch_size', [2])
@pytest.mark.parametrize('num_features', [1]) def test_waveform(self):
@pytest.mark.parametrize('num_classes', [40]) batch_size = 2
@pytest.mark.parametrize('input_length', [320]) num_features = 1
def test_waveform(self, batch_size, num_features, num_classes, input_length): num_classes = 40
model = Wav2Letter() input_length = 320
model = Wav2Letter(num_classes=num_classes, num_features=num_features)
x = torch.rand(batch_size, num_features, input_length) x = torch.rand(batch_size, num_features, input_length)
out = model(x) out = model(x)
assert out.size() == (batch_size, num_classes, 2) assert out.size() == (batch_size, num_classes, 2)
@pytest.mark.parametrize('batch_size', [2]) def test_mfcc(self):
@pytest.mark.parametrize('num_features', [13]) batch_size = 2
@pytest.mark.parametrize('num_classes', [40]) num_features = 13
@pytest.mark.parametrize('input_length', [2]) num_classes = 40
def test_mfcc(self, batch_size, num_features, num_classes, input_length): input_length = 2
model = Wav2Letter(input_type="mfcc", num_features=13)
model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)
x = torch.rand(batch_size, num_features, input_length) x = torch.rand(batch_size, num_features, input_length)
out = model(x) out = model(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment