"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "27bf7fcd0e8069c623b564dd7024ea782b69dca8"
Unverified Commit 5c696b50 authored by moto's avatar moto Committed by GitHub
Browse files

Save/load TorchScript object in test (#1446)

parent 931555c1
...@@ -6,17 +6,21 @@ import torchaudio.functional as F ...@@ -6,17 +6,21 @@ import torchaudio.functional as F
from parameterized import parameterized from parameterized import parameterized
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
skipIfRocm, skipIfRocm,
) )
class Functional(common_utils.TestBaseMixin): class Functional(TempDirMixin, TestBaseMixin):
"""Implements test for `functinoal` modul that are performed for different devices""" """Implements test for `functinoal` modul that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False): def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype) tensor = tensor.to(device=self.device, dtype=self.dtype)
ts_func = torch.jit.script(func) path = self.get_temp_path('func.zip')
torch.jit.script(func).save(path)
ts_func = torch.jit.load(path)
output = func(tensor) output = func(tensor)
ts_output = ts_func(tensor) ts_output = ts_func(tensor)
if shape_only: if shape_only:
...@@ -565,7 +569,7 @@ class Functional(common_utils.TestBaseMixin): ...@@ -565,7 +569,7 @@ class Functional(common_utils.TestBaseMixin):
self._assert_consistency(func, tensor) self._assert_consistency(func, tensor)
class FunctionalComplex: class FunctionalComplex(TempDirMixin, TestBaseMixin):
complex_dtype = None complex_dtype = None
real_dtype = None real_dtype = None
device = None device = None
...@@ -573,7 +577,10 @@ class FunctionalComplex: ...@@ -573,7 +577,10 @@ class FunctionalComplex:
def _assert_consistency(self, func, tensor, test_pseudo_complex=False): def _assert_consistency(self, func, tensor, test_pseudo_complex=False):
assert tensor.is_complex() assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype) tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
ts_func = torch.jit.script(func)
path = self.get_temp_path('func.zip')
torch.jit.script(func).save(path)
ts_func = torch.jit.load(path)
if test_pseudo_complex: if test_pseudo_complex:
tensor = torch.view_as_real(tensor) tensor = torch.view_as_real(tensor)
......
...@@ -7,16 +7,21 @@ from parameterized import parameterized ...@@ -7,16 +7,21 @@ from parameterized import parameterized
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
skipIfRocm, skipIfRocm,
TempDirMixin,
TestBaseMixin,
) )
class Transforms(common_utils.TestBaseMixin): class Transforms(TempDirMixin, TestBaseMixin):
"""Implements test for Transforms that are performed for different devices""" """Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor): def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype) tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype) transform = transform.to(device=self.device, dtype=self.dtype)
ts_transform = torch.jit.script(transform) path = self.get_temp_path('transform.zip')
torch.jit.script(transform).save(path)
ts_transform = torch.jit.load(path)
output = transform(tensor) output = transform(tensor)
ts_output = ts_transform(tensor) ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output) self.assertEqual(ts_output, output)
...@@ -39,8 +44,8 @@ class Transforms(common_utils.TestBaseMixin): ...@@ -39,8 +44,8 @@ class Transforms(common_utils.TestBaseMixin):
self._assert_consistency(T.AmplitudeToDB(), spec) self._assert_consistency(T.AmplitudeToDB(), spec)
def test_MelScale(self): def test_MelScale(self):
spec_f = torch.rand((1, 6, 201)) spec_f = torch.rand((1, 201, 6))
self._assert_consistency(T.MelScale(), spec_f) self._assert_consistency(T.MelScale(n_stft=201), spec_f)
def test_MelSpectrogram(self): def test_MelSpectrogram(self):
tensor = torch.rand((1, 1000)) tensor = torch.rand((1, 1000))
...@@ -100,7 +105,7 @@ class Transforms(common_utils.TestBaseMixin): ...@@ -100,7 +105,7 @@ class Transforms(common_utils.TestBaseMixin):
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform) self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)
class TransformsComplex: class TransformsComplex(TempDirMixin, TestBaseMixin):
complex_dtype = None complex_dtype = None
real_dtype = None real_dtype = None
device = None device = None
...@@ -109,7 +114,10 @@ class TransformsComplex: ...@@ -109,7 +114,10 @@ class TransformsComplex:
assert tensor.is_complex() assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype) tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.real_dtype) transform = transform.to(device=self.device, dtype=self.real_dtype)
ts_transform = torch.jit.script(transform)
path = self.get_temp_path('transform.zip')
torch.jit.script(transform).save(path)
ts_transform = torch.jit.load(path)
if test_pseudo_complex: if test_pseudo_complex:
tensor = torch.view_as_real(tensor) tensor = torch.view_as_real(tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment