Unverified commit 595b37b6 authored by moto, committed by GitHub

Refactor scripting in test (#1727)

Introduce a helper function `torch_script` that performs scripting in the recommended way.
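In effect, the repeated script-then-save-then-load sequence in each test collapses into one call. Roughly (names follow the diff below; this is an illustrative summary, not part of the patch):

    # before: every test scripted the callable and round-tripped it through a temp file
    path = self.get_temp_path('func.zip')
    torch.jit.script(fn).save(path)
    ts_func = torch.jit.load(path)

    # after: one helper from torchaudio_unittest.common_utils
    ts_func = torch_script(fn)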
parent ef7255bb
@@ -14,6 +14,7 @@ from torchaudio_unittest.common_utils import (
     save_wav,
     load_wav,
     sox_utils,
+    torch_script,
 )
 from .common import (
     name_func,
@@ -61,9 +62,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
         data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
         save_wav(audio_path, data, sample_rate)

-        script_path = self.get_temp_path('info_func.zip')
-        torch.jit.script(py_info_func).save(script_path)
-        ts_info_func = torch.jit.load(script_path)
+        ts_info_func = torch_script(py_info_func)

         py_info = py_info_func(audio_path)
         ts_info = ts_info_func(audio_path)
@@ -85,9 +84,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
         data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
         save_wav(audio_path, data, sample_rate)

-        script_path = self.get_temp_path('load_func.zip')
-        torch.jit.script(py_load_func).save(script_path)
-        ts_load_func = torch.jit.load(script_path)
+        ts_load_func = torch_script(py_load_func)

         py_data, py_sr = py_load_func(
             audio_path, normalize=normalize, channels_first=channels_first)
@@ -103,9 +100,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
         [1, 2],
     )), name_func=name_func)
     def test_save_wav(self, dtype, sample_rate, num_channels):
-        script_path = self.get_temp_path('save_func.zip')
-        torch.jit.script(py_save_func).save(script_path)
-        ts_save_func = torch.jit.load(script_path)
+        ts_save_func = torch_script(py_save_func)

         expected = get_wav_data(dtype, num_channels, normalize=False)
         py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
@@ -129,9 +124,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
         list(range(9)),
     )), name_func=name_func)
     def test_save_flac(self, sample_rate, num_channels, compression_level):
-        script_path = self.get_temp_path('save_func.zip')
-        torch.jit.script(py_save_func).save(script_path)
-        ts_save_func = torch.jit.load(script_path)
+        ts_save_func = torch_script(py_save_func)

         expected = get_wav_data('float32', num_channels)
         py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
...
@@ -31,6 +31,8 @@ from .parameterized_utils import (
     load_params,
     nested_params
 )
+from .func_utils import torch_script

 __all__ = [
     'get_asset_path',
@@ -57,4 +59,5 @@ __all__ = [
     'save_wav',
     'load_params',
     'nested_params',
+    'torch_script',
 ]
import io

import torch


def torch_script(obj):
    """TorchScript the given function or Module"""
    buffer = io.BytesIO()
    torch.jit.save(torch.jit.script(obj), buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)
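The helper scripts the object and round-trips it through an in-memory buffer, so the tests still exercise TorchScript serialization but no longer need a temp directory (which is presumably why several TempDirMixin bases are dropped below). A minimal usage sketch; `add_one` is a hypothetical function used only for illustration:

    import torch
    from torchaudio_unittest.common_utils import torch_script

    def add_one(x: torch.Tensor) -> torch.Tensor:
        # hypothetical example function, not part of the commit
        return x + 1

    ts_add_one = torch_script(add_one)
    assert torch.equal(ts_add_one(torch.ones(3)), add_one(torch.ones(3)))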
@@ -8,16 +8,16 @@ from .tacotron2_loss_impl import (
 from torchaudio_unittest.common_utils import PytorchTestCase


-class TestTacotron2LossShapeFloat32CPU(PytorchTestCase, Tacotron2LossShapeTests):
+class TestTacotron2LossShapeFloat32CPU(Tacotron2LossShapeTests, PytorchTestCase):
     dtype = torch.float32
     device = torch.device("cpu")


-class TestTacotron2TorchsciptFloat32CPU(PytorchTestCase, Tacotron2LossTorchscriptTests):
+class TestTacotron2TorchsciptFloat32CPU(Tacotron2LossTorchscriptTests, PytorchTestCase):
     dtype = torch.float32
     device = torch.device("cpu")


-class TestTacotron2GradcheckFloat64CPU(PytorchTestCase, Tacotron2LossGradcheckTests):
+class TestTacotron2GradcheckFloat64CPU(Tacotron2LossGradcheckTests, PytorchTestCase):
     dtype = torch.float64  # gradcheck needs a higher numerical accuracy
     device = torch.device("cpu")
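A note on the reordered base classes above (my reading; the commit message does not explain it): Python resolves attributes left to right along the MRO, so listing the test-implementation mixin before PytorchTestCase lets the mixin's definitions take precedence when both bases define the same name. A minimal sketch of that rule:

    class Base:
        def value(self):
            return "base"

    class Mixin:
        def value(self):
            return "mixin"

    class A(Base, Mixin):   # leftmost base wins
        pass

    class B(Mixin, Base):
        pass

    print(A().value())  # prints "base"
    print(B().value())  # prints "mixin"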
@@ -2,10 +2,13 @@ import torch
 from torch.autograd import gradcheck, gradgradcheck

 from pipeline_tacotron2.loss import Tacotron2Loss
-from torchaudio_unittest.common_utils import TempDirMixin
+from torchaudio_unittest.common_utils import (
+    TestBaseMixin,
+    torch_script,
+)


-class Tacotron2LossInputMixin(TempDirMixin):
+class Tacotron2LossInputMixin(TestBaseMixin):

     def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300):
         mel_specgram = torch.rand(
@@ -59,9 +62,7 @@ class Tacotron2LossShapeTests(Tacotron2LossInputMixin):
 class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin):

     def _assert_torchscript_consistency(self, fn, tensors):
-        path = self.get_temp_path("func.zip")
-        torch.jit.script(fn).save(path)
-        ts_func = torch.jit.load(path)
+        ts_func = torch_script(fn)

         output = fn(tensors[:3], tensors[3:])
         ts_output = ts_func(tensors[:3], tensors[3:])
...
@@ -6,9 +6,11 @@ import torchaudio.functional as F
 from parameterized import parameterized

 from torchaudio_unittest import common_utils
-from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
 from torchaudio_unittest.common_utils import (
+    TempDirMixin,
+    TestBaseMixin,
     skipIfRocm,
+    torch_script,
 )
@@ -16,10 +18,7 @@ class Functional(TempDirMixin, TestBaseMixin):
     """Implements test for `functional` module that are performed for different devices"""
     def _assert_consistency(self, func, tensor, shape_only=False):
         tensor = tensor.to(device=self.device, dtype=self.dtype)
-
-        path = self.get_temp_path('func.zip')
-        torch.jit.script(func).save(path)
-        ts_func = torch.jit.load(path)
+        ts_func = torch_script(func)

         torch.random.manual_seed(40)
         output = func(tensor)
@@ -35,10 +34,7 @@ class Functional(TempDirMixin, TestBaseMixin):
     def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False):
         assert tensor.is_complex()
         tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
-
-        path = self.get_temp_path('func.zip')
-        torch.jit.script(func).save(path)
-        ts_func = torch.jit.load(path)
+        ts_func = torch_script(func)

         if test_pseudo_complex:
             tensor = torch.view_as_real(tensor)
...
@@ -3,10 +3,7 @@ import torch
 from torch import Tensor
 from torchaudio.models import Tacotron2
 from torchaudio.models.tacotron2 import _Encoder, _Decoder
-from torchaudio_unittest.common_utils import (
-    TestBaseMixin,
-    TempDirMixin,
-)
+from torchaudio_unittest.common_utils import TestBaseMixin, torch_script


 class Tacotron2InferenceWrapper(torch.nn.Module):
@@ -29,13 +26,11 @@ class Tacotron2DecoderInferenceWrapper(torch.nn.Module):
         return self.model.infer(memory, memory_lengths)


-class TorchscriptConsistencyMixin(TempDirMixin):
+class TorchscriptConsistencyMixin(TestBaseMixin):
     r"""Mixin to provide easy access assert torchscript consistency"""

     def _assert_torchscript_consistency(self, model, tensors):
-        path = self.get_temp_path("func.zip")
-        torch.jit.script(model).save(path)
-        ts_func = torch.jit.load(path)
+        ts_func = torch_script(model)

         torch.random.manual_seed(40)
         output = model(*tensors)
@@ -46,7 +41,7 @@ class TorchscriptConsistencyMixin(TempDirMixin):
         self.assertEqual(ts_output, output)


-class Tacotron2EncoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
+class Tacotron2EncoderTests(TorchscriptConsistencyMixin):

     def test_tacotron2_torchscript_consistency(self):
         r"""Validate the torchscript consistency of a Encoder."""
@@ -105,7 +100,7 @@ def _get_decoder_model(n_mels=80, encoder_embedding_dim=512,
     return model


-class Tacotron2DecoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
+class Tacotron2DecoderTests(TorchscriptConsistencyMixin):

     def test_decoder_torchscript_consistency(self):
         r"""Validate the torchscript consistency of a Decoder."""
@@ -252,7 +247,7 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
     )


-class Tacotron2Tests(TestBaseMixin, TorchscriptConsistencyMixin):
+class Tacotron2Tests(TorchscriptConsistencyMixin):

     def _get_inputs(
         self, n_mels: int, n_batch: int, max_mel_specgram_length: int, max_text_length: int
...
-import io
 import torch
 import torch.nn.functional as F
@@ -11,6 +10,7 @@ from torchaudio_unittest.common_utils import (
     TorchaudioTestCase,
     skipIfNoQengine,
     skipIfNoCuda,
+    torch_script,
 )

 from parameterized import parameterized
@@ -112,13 +112,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
         ref_out, ref_len = model(waveforms, lengths)

-        # TODO: put this in a common method of Mixin class.
-        # Script
-        scripted = torch.jit.script(model)
-        buffer_ = io.BytesIO()
-        torch.jit.save(scripted, buffer_)
-        buffer_.seek(0)
-        scripted = torch.jit.load(buffer_)
+        scripted = torch_script(model)

         hyp_out, hyp_len = scripted(waveforms, lengths)
@@ -170,11 +164,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
         ref_out, ref_len = quantized(waveforms, lengths)

         # Script
-        scripted = torch.jit.script(quantized)
-        buffer_ = io.BytesIO()
-        torch.jit.save(scripted, buffer_)
-        buffer_.seek(0)
-        scripted = torch.jit.load(buffer_)
+        scripted = torch_script(quantized)

         hyp_out, hyp_len = scripted(waveforms, lengths)
...
@@ -6,10 +6,11 @@ from parameterized import parameterized
 from torchaudio_unittest.common_utils import (
     TempDirMixin,
-    PytorchTestCase,
+    TorchaudioTestCase,
     skipIfNoSox,
     get_sinusoid,
     save_wav,
+    torch_script,
 )
 from .common import (
     load_params,
@@ -44,7 +45,7 @@ class SoxEffectFileTransform(torch.nn.Module):
 @skipIfNoSox
-class TestTorchScript(TempDirMixin, PytorchTestCase):
+class TestTorchScript(TempDirMixin, TorchaudioTestCase):
     @parameterized.expand(
         load_params("sox_effect_test_args.jsonl"),
         name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
@@ -57,9 +58,7 @@ class TestTorchScript(TempDirMixin, PytorchTestCase):
         trans = SoxEffectTensorTransform(effects, input_sr, channels_first)

-        path = self.get_temp_path('sox_effect.zip')
-        torch.jit.script(trans).save(path)
-        trans = torch.jit.load(path)
+        trans = torch_script(trans)

         wav = get_sinusoid(
             frequency=800, sample_rate=input_sr,
@@ -82,10 +81,7 @@ class TestTorchScript(TempDirMixin, PytorchTestCase):
         input_sr = args.get("input_sample_rate", 8000)
         trans = SoxEffectFileTransform(effects, channels_first)
-
-        path = self.get_temp_path('sox_effect.zip')
-        torch.jit.script(trans).save(path)
-        trans = torch.jit.load(path)
+        trans = torch_script(trans)

         path = self.get_temp_path('input.wav')
         wav = get_sinusoid(
...
@@ -7,20 +7,18 @@ from parameterized import parameterized
 from torchaudio_unittest import common_utils
 from torchaudio_unittest.common_utils import (
     skipIfRocm,
-    TempDirMixin,
     TestBaseMixin,
+    torch_script,
 )


-class Transforms(TempDirMixin, TestBaseMixin):
+class Transforms(TestBaseMixin):
     """Implements test for Transforms that are performed for different devices"""
     def _assert_consistency(self, transform, tensor, *args):
         tensor = tensor.to(device=self.device, dtype=self.dtype)
         transform = transform.to(device=self.device, dtype=self.dtype)

-        path = self.get_temp_path('transform.zip')
-        torch.jit.script(transform).save(path)
-        ts_transform = torch.jit.load(path)
+        ts_transform = torch_script(transform)

         output = transform(tensor, *args)
         ts_output = ts_transform(tensor, *args)
@@ -31,9 +29,7 @@ class Transforms(TempDirMixin, TestBaseMixin):
         tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
         transform = transform.to(device=self.device, dtype=self.dtype)

-        path = self.get_temp_path('transform.zip')
-        torch.jit.script(transform).save(path)
-        ts_transform = torch.jit.load(path)
+        ts_transform = torch_script(transform)

         if test_pseudo_complex:
             tensor = torch.view_as_real(tensor)
...