Commit 691317a9 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add tacotron2 unittest with different batch_size (#2176)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2176

Reviewed By: carolineechen, mthrok

Differential Revision: D33794216

Pulled By: nateanl

fbshipit-source-id: e039c1fc03a89f1e8130a5c4dbc4beceff4081eb
parent 0d6d0669
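
This commit replaces the hardcoded `n_batch = 16` local in each Tacotron2 test with a `parameterized.expand` decorator, so every test now runs once per batch size (1 and 16) and batch-size-dependent regressions (for example, dimensions accidentally squeezed at batch size 1) are caught. Below is a minimal, self-contained sketch of the pattern applied throughout the diff; the `ExampleTests` class and its body are illustrative only, not part of the commit:

```python
import unittest

from parameterized import parameterized


class ExampleTests(unittest.TestCase):
    # Each tuple in the list becomes a separate generated test case;
    # its elements are passed to the method as positional arguments.
    @parameterized.expand(
        [
            (1,),
            (16,),
        ]
    )
    def test_batch_dependent_behavior(self, n_batch):
        # n_batch is supplied by the decorator instead of being
        # hardcoded inside the test body.
        self.assertGreaterEqual(n_batch, 1)


if __name__ == "__main__":
    unittest.main()
```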
 from typing import Tuple
 import torch
+from parameterized import parameterized
 from torch import Tensor
 from torchaudio.models import Tacotron2
 from torchaudio.models.tacotron2 import _Encoder, _Decoder
@@ -94,9 +95,14 @@ def _get_decoder_model(n_mels=80, encoder_embedding_dim=512, decoder_max_step=20
 class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
-    def test_decoder_torchscript_consistency(self):
+    @parameterized.expand(
+        [
+            (1,),
+            (16,),
+        ]
+    )
+    def test_decoder_torchscript_consistency(self, n_batch):
         r"""Validate the torchscript consistency of a Decoder."""
-        n_batch = 16
         n_mels = 80
         n_seq = 200
         encoder_embedding_dim = 256
@@ -111,11 +117,16 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
         self._assert_torchscript_consistency(model, (memory, decoder_inputs, memory_lengths))

-    def test_decoder_output_shape(self):
+    @parameterized.expand(
+        [
+            (1,),
+            (16,),
+        ]
+    )
+    def test_decoder_output_shape(self, n_batch):
         r"""Feed tensors with a specific shape to the Tacotron2 Decoder and validate
         that it outputs tensors with the expected shapes.
         """
-        n_batch = 16
         n_mels = 80
         n_seq = 200
         encoder_embedding_dim = 256
@@ -134,9 +145,14 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
         assert gate_outputs.size() == (n_batch, n_time_steps)
         assert alignments.size() == (n_batch, n_time_steps, n_seq)

-    def test_decoder_inference_torchscript_consistency(self):
+    @parameterized.expand(
+        [
+            (1,),
+            (16,),
+        ]
+    )
+    def test_decoder_inference_torchscript_consistency(self, n_batch):
         r"""Validate the torchscript consistency of the Decoder inference function."""
-        n_batch = 16
         n_mels = 80
         n_seq = 200
         encoder_embedding_dim = 256
@@ -158,9 +174,14 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
         self._assert_torchscript_consistency(model_wrapper, (memory, memory_lengths))

-    def test_decoder_inference_output_shape(self):
+    @parameterized.expand(
+        [
+            (1,),
+            (16,),
+        ]
+    )
+    def test_decoder_inference_output_shape(self, n_batch):
         r"""Validate the output shapes of the Decoder inference function."""
-        n_batch = 16
         n_mels = 80
         n_seq = 200
         encoder_embedding_dim = 256
@@ -238,9 +259,14 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
         mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,), dtype=torch.int32, device=self.device)
         return text, text_lengths, mel_specgram, mel_specgram_lengths

-    def test_tacotron2_torchscript_consistency(self):
+    @parameterized.expand(
+        [
+            (1,),
+            (16,),
+        ]
+    )
+    def test_tacotron2_torchscript_consistency(self, n_batch):
         r"""Validate the torchscript consistency of a Tacotron2."""
-        n_batch = 16
         n_mels = 80
         max_mel_specgram_length = 300
         max_text_length = 100
@@ -250,11 +276,16 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
         self._assert_torchscript_consistency(model, inputs)

-    def test_tacotron2_output_shape(self):
+    @parameterized.expand(
+        [
+            (1,),
+            (16,),
+        ]
+    )
+    def test_tacotron2_output_shape(self, n_batch):
         r"""Feed tensors with a specific shape to Tacotron2 and validate
         that it outputs tensors with the expected shapes.
         """
-        n_batch = 16
         n_mels = 80
         max_mel_specgram_length = 300
         max_text_length = 100
@@ -268,12 +299,17 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
         assert gate_outputs.size() == (n_batch, max_mel_specgram_length)
         assert alignments.size() == (n_batch, max_mel_specgram_length, max_text_length)

-    def test_tacotron2_backward(self):
+    @parameterized.expand(
+        [
+            (1,),
+            (16,),
+        ]
+    )
+    def test_tacotron2_backward(self, n_batch):
         r"""Make sure calling the backward function on Tacotron2's outputs does
         not error out. Following:
         https://github.com/pytorch/vision/blob/23b8760374a5aaed53c6e5fc83a7e83dbe3b85df/test/test_models.py#L255
         """
-        n_batch = 16
         n_mels = 80
         max_mel_specgram_length = 300
         max_text_length = 100
@@ -291,9 +327,14 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
         text_lengths = max_text_length * torch.ones((n_batch,), dtype=torch.int32, device=self.device)
         return text, text_lengths

-    def test_tacotron2_inference_torchscript_consistency(self):
+    @parameterized.expand(
+        [
+            (1,),
+            (16,),
+        ]
+    )
+    def test_tacotron2_inference_torchscript_consistency(self, n_batch):
         r"""Validate the torchscript consistency of the Tacotron2 inference function."""
-        n_batch = 16
         n_mels = 40
         max_text_length = 100
         decoder_max_step = 200  # make inference more efficient
@@ -310,11 +351,16 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
         self._assert_torchscript_consistency(model_wrapper, inputs)

-    def test_tacotron2_inference_output_shape(self):
+    @parameterized.expand(
+        [
+            (1,),
+            (16,),
+        ]
+    )
+    def test_tacotron2_inference_output_shape(self, n_batch):
         r"""Feed tensors with a specific shape to the Tacotron2 inference function and validate
         that it outputs tensors with the expected shapes.
         """
-        n_batch = 16
         n_mels = 40
         max_text_length = 100
         decoder_max_step = 200  # make inference more efficient
...
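A note on the resulting test names: `parameterized.expand` generates a separately named test per entry by appending a numeric suffix to the method name (e.g. `test_tacotron2_output_shape_0` and `test_tacotron2_output_shape_1`), so a failure at one batch size is reported independently of the other. The test file's path is not shown in this truncated diff; assuming the usual torchaudio layout, the expanded cases could be selected with something like `pytest test/torchaudio_unittest -k tacotron2`.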