Unverified Commit 595b37b6 authored by moto, committed by GitHub

Refactor scripting in test (#1727)

Introduce a helper function `torch_script` that performs scripting in the recommended way.
parent ef7255bb
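In outline, each test used to script a function or module, save the TorchScript archive to a temporary file, and load it back before comparing outputs; the new helper performs the same round-trip through an in-memory buffer. A minimal sketch of the pattern used as a consistency check; the test target add_one and the final comparison are illustrative assumptions, not code from this commit:

import io

import torch


def torch_script(obj):
    # Script the object, serialize it into an in-memory buffer, and load it
    # back, so the save/load round-trip is exercised without touching disk.
    buffer = io.BytesIO()
    torch.jit.save(torch.jit.script(obj), buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)


def add_one(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical test target, used only for this sketch.
    return x + 1


ts_add_one = torch_script(add_one)
x = torch.arange(4, dtype=torch.float32)
# The scripted round-trip should match eager execution exactly.
assert torch.equal(add_one(x), ts_add_one(x))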
@@ -14,6 +14,7 @@ from torchaudio_unittest.common_utils import (
save_wav,
load_wav,
sox_utils,
+ torch_script,
)
from .common import (
name_func,
@@ -61,9 +62,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate)
- script_path = self.get_temp_path('info_func.zip')
- torch.jit.script(py_info_func).save(script_path)
- ts_info_func = torch.jit.load(script_path)
+ ts_info_func = torch_script(py_info_func)
py_info = py_info_func(audio_path)
ts_info = ts_info_func(audio_path)
@@ -85,9 +84,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate)
- script_path = self.get_temp_path('load_func.zip')
- torch.jit.script(py_load_func).save(script_path)
- ts_load_func = torch.jit.load(script_path)
+ ts_load_func = torch_script(py_load_func)
py_data, py_sr = py_load_func(
audio_path, normalize=normalize, channels_first=channels_first)
@@ -103,9 +100,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
[1, 2],
)), name_func=name_func)
def test_save_wav(self, dtype, sample_rate, num_channels):
- script_path = self.get_temp_path('save_func.zip')
- torch.jit.script(py_save_func).save(script_path)
- ts_save_func = torch.jit.load(script_path)
+ ts_save_func = torch_script(py_save_func)
expected = get_wav_data(dtype, num_channels, normalize=False)
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
@@ -129,9 +124,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
list(range(9)),
)), name_func=name_func)
def test_save_flac(self, sample_rate, num_channels, compression_level):
- script_path = self.get_temp_path('save_func.zip')
- torch.jit.script(py_save_func).save(script_path)
- ts_save_func = torch.jit.load(script_path)
+ ts_save_func = torch_script(py_save_func)
expected = get_wav_data('float32', num_channels)
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
@@ -31,6 +31,8 @@ from .parameterized_utils import (
load_params,
nested_params
)
+ from .func_utils import torch_script
__all__ = [
'get_asset_path',
@@ -57,4 +59,5 @@ __all__ = [
'save_wav',
'load_params',
'nested_params',
+ 'torch_script',
]
import io

import torch


def torch_script(obj):
    """TorchScript the given function or Module"""
    buffer = io.BytesIO()
    torch.jit.save(torch.jit.script(obj), buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)
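The helper accepts both free functions and torch.nn.Module instances, which is how the transform, sox_effects, and model tests below use it. A short usage sketch, assuming the torchaudio test utilities are importable; the Doubler module is a hypothetical stand-in, not part of torchaudio:

import torch

from torchaudio_unittest.common_utils import torch_script


class Doubler(torch.nn.Module):
    # Hypothetical module, used only for this sketch.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


module = Doubler()
scripted = torch_script(module)  # torch.jit.script plus an in-memory save/load round-trip
x = torch.ones(3)
assert torch.equal(module(x), scripted(x))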
@@ -8,16 +8,16 @@ from .tacotron2_loss_impl import (
from torchaudio_unittest.common_utils import PytorchTestCase
- class TestTacotron2LossShapeFloat32CPU(PytorchTestCase, Tacotron2LossShapeTests):
+ class TestTacotron2LossShapeFloat32CPU(Tacotron2LossShapeTests, PytorchTestCase):
dtype = torch.float32
device = torch.device("cpu")
- class TestTacotron2TorchsciptFloat32CPU(PytorchTestCase, Tacotron2LossTorchscriptTests):
+ class TestTacotron2TorchsciptFloat32CPU(Tacotron2LossTorchscriptTests, PytorchTestCase):
dtype = torch.float32
device = torch.device("cpu")
- class TestTacotron2GradcheckFloat64CPU(PytorchTestCase, Tacotron2LossGradcheckTests):
+ class TestTacotron2GradcheckFloat64CPU(Tacotron2LossGradcheckTests, PytorchTestCase):
dtype = torch.float64 # gradcheck needs a higher numerical accuracy
device = torch.device("cpu")
@@ -2,10 +2,13 @@ import torch
from torch.autograd import gradcheck, gradgradcheck
from pipeline_tacotron2.loss import Tacotron2Loss
- from torchaudio_unittest.common_utils import TempDirMixin
+ from torchaudio_unittest.common_utils import (
+ TestBaseMixin,
+ torch_script,
+ )
- class Tacotron2LossInputMixin(TempDirMixin):
+ class Tacotron2LossInputMixin(TestBaseMixin):
def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300):
mel_specgram = torch.rand(
@@ -59,9 +62,7 @@ class Tacotron2LossShapeTests(Tacotron2LossInputMixin):
class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin):
def _assert_torchscript_consistency(self, fn, tensors):
- path = self.get_temp_path("func.zip")
- torch.jit.script(fn).save(path)
- ts_func = torch.jit.load(path)
+ ts_func = torch_script(fn)
output = fn(tensors[:3], tensors[3:])
ts_output = ts_func(tensors[:3], tensors[3:])
@@ -6,9 +6,11 @@ import torchaudio.functional as F
from parameterized import parameterized
from torchaudio_unittest import common_utils
- from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
+ from torchaudio_unittest.common_utils import (
+ TempDirMixin,
+ TestBaseMixin,
+ skipIfRocm,
+ torch_script,
+ )
@@ -16,10 +18,7 @@ class Functional(TempDirMixin, TestBaseMixin):
"""Implements test for `functional` module that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype)
- path = self.get_temp_path('func.zip')
- torch.jit.script(func).save(path)
- ts_func = torch.jit.load(path)
+ ts_func = torch_script(func)
torch.random.manual_seed(40)
output = func(tensor)
@@ -35,10 +34,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
- path = self.get_temp_path('func.zip')
- torch.jit.script(func).save(path)
- ts_func = torch.jit.load(path)
+ ts_func = torch_script(func)
if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
@@ -3,10 +3,7 @@ import torch
from torch import Tensor
from torchaudio.models import Tacotron2
from torchaudio.models.tacotron2 import _Encoder, _Decoder
- from torchaudio_unittest.common_utils import (
- TestBaseMixin,
- TempDirMixin,
- )
+ from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class Tacotron2InferenceWrapper(torch.nn.Module):
@@ -29,13 +26,11 @@ class Tacotron2DecoderInferenceWrapper(torch.nn.Module):
return self.model.infer(memory, memory_lengths)
- class TorchscriptConsistencyMixin(TempDirMixin):
+ class TorchscriptConsistencyMixin(TestBaseMixin):
r"""Mixin to provide easy access assert torchscript consistency"""
def _assert_torchscript_consistency(self, model, tensors):
- path = self.get_temp_path("func.zip")
- torch.jit.script(model).save(path)
- ts_func = torch.jit.load(path)
+ ts_func = torch_script(model)
torch.random.manual_seed(40)
output = model(*tensors)
@@ -46,7 +41,7 @@ class TorchscriptConsistencyMixin(TempDirMixin):
self.assertEqual(ts_output, output)
- class Tacotron2EncoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
+ class Tacotron2EncoderTests(TorchscriptConsistencyMixin):
def test_tacotron2_torchscript_consistency(self):
r"""Validate the torchscript consistency of a Encoder."""
@@ -105,7 +100,7 @@ def _get_decoder_model(n_mels=80, encoder_embedding_dim=512,
return model
- class Tacotron2DecoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
+ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
def test_decoder_torchscript_consistency(self):
r"""Validate the torchscript consistency of a Decoder."""
@@ -252,7 +247,7 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
)
- class Tacotron2Tests(TestBaseMixin, TorchscriptConsistencyMixin):
+ class Tacotron2Tests(TorchscriptConsistencyMixin):
def _get_inputs(
self, n_mels: int, n_batch: int, max_mel_specgram_length: int, max_text_length: int
- import io
import torch
import torch.nn.functional as F
@@ -11,6 +10,7 @@ from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
skipIfNoQengine,
skipIfNoCuda,
+ torch_script,
)
from parameterized import parameterized
@@ -112,13 +112,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
ref_out, ref_len = model(waveforms, lengths)
- # TODO: put this in a common method of Mixin class.
- # Script
- scripted = torch.jit.script(model)
- buffer_ = io.BytesIO()
- torch.jit.save(scripted, buffer_)
- buffer_.seek(0)
- scripted = torch.jit.load(buffer_)
+ scripted = torch_script(model)
hyp_out, hyp_len = scripted(waveforms, lengths)
@@ -170,11 +164,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
ref_out, ref_len = quantized(waveforms, lengths)
# Script
- scripted = torch.jit.script(quantized)
- buffer_ = io.BytesIO()
- torch.jit.save(scripted, buffer_)
- buffer_.seek(0)
- scripted = torch.jit.load(buffer_)
+ scripted = torch_script(quantized)
hyp_out, hyp_len = scripted(waveforms, lengths)
@@ -6,10 +6,11 @@ from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
- PytorchTestCase,
+ TorchaudioTestCase,
skipIfNoSox,
get_sinusoid,
save_wav,
+ torch_script,
)
from .common import (
load_params,
@@ -44,7 +45,7 @@ class SoxEffectFileTransform(torch.nn.Module):
@skipIfNoSox
- class TestTorchScript(TempDirMixin, PytorchTestCase):
+ class TestTorchScript(TempDirMixin, TorchaudioTestCase):
@parameterized.expand(
load_params("sox_effect_test_args.jsonl"),
name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
@@ -57,9 +58,7 @@ class TestTorchScript(TempDirMixin, PytorchTestCase):
trans = SoxEffectTensorTransform(effects, input_sr, channels_first)
- path = self.get_temp_path('sox_effect.zip')
- torch.jit.script(trans).save(path)
- trans = torch.jit.load(path)
+ trans = torch_script(trans)
wav = get_sinusoid(
frequency=800, sample_rate=input_sr,
@@ -82,10 +81,7 @@ class TestTorchScript(TempDirMixin, PytorchTestCase):
input_sr = args.get("input_sample_rate", 8000)
trans = SoxEffectFileTransform(effects, channels_first)
- path = self.get_temp_path('sox_effect.zip')
- torch.jit.script(trans).save(path)
- trans = torch.jit.load(path)
+ trans = torch_script(trans)
path = self.get_temp_path('input.wav')
wav = get_sinusoid(
@@ -7,20 +7,18 @@ from parameterized import parameterized
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
skipIfRocm,
- TempDirMixin,
TestBaseMixin,
+ torch_script,
)
- class Transforms(TempDirMixin, TestBaseMixin):
+ class Transforms(TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor, *args):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)
- path = self.get_temp_path('transform.zip')
- torch.jit.script(transform).save(path)
- ts_transform = torch.jit.load(path)
+ ts_transform = torch_script(transform)
output = transform(tensor, *args)
ts_output = ts_transform(tensor, *args)
@@ -31,9 +29,7 @@ class Transforms(TempDirMixin, TestBaseMixin):
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.dtype)
- path = self.get_temp_path('transform.zip')
- torch.jit.script(transform).save(path)
- ts_transform = torch.jit.load(path)
+ ts_transform = torch_script(transform)
if test_pseudo_complex:
tensor = torch.view_as_real(tensor)