Commit 289f08af authored by jamarshon's avatar jamarshon Committed by cpuhrsch
Browse files

more (#160)

parent 6a43e9eb
...@@ -30,28 +30,6 @@ class Test_JIT(unittest.TestCase):
self.assertTrue(torch.allclose(jit_out, py_out))
def test_torchscript_pad_trim(self):
    """The scripted F.pad_trim must produce the same output as the eager call."""
    @torch.jit.script
    def jit_method(tensor, max_len, fill_value):
        # type: (Tensor, int, float) -> Tensor
        return F.pad_trim(tensor, max_len, fill_value)

    waveform = torch.rand((1, 10))
    target_length = 5
    pad_value = 3.
    scripted_result = jit_method(waveform, target_length, pad_value)
    eager_result = F.pad_trim(waveform, target_length, pad_value)
    self.assertTrue(torch.allclose(scripted_result, eager_result))
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_pad_trim(self):
    """Run the PadTrim transform through the shared ScriptModule harness on CUDA."""
    waveform = torch.rand((1, 10), device="cuda")
    target_length = 5
    self._test_script_module(waveform, transforms.PadTrim, target_length)
def test_torchscript_spectrogram(self):
@torch.jit.script
def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize):
......
...@@ -36,20 +36,6 @@ class Tester(unittest.TestCase):
waveform = waveform.to(torch.get_default_dtype())
return waveform / factor
def test_pad_trim(self):
    """PadTrim must return exactly max_len samples whether it pads or trims."""
    waveform = self.waveform.clone()
    original_length = waveform.size(1)
    # 1.2x forces the padding branch, 0.8x forces the trimming branch.
    for scale in (1.2, 0.8):
        target_length = int(original_length * scale)
        result = transforms.PadTrim(max_len=target_length)(waveform)
        self.assertEqual(result.size(1), target_length)
def test_mu_law_companding(self):
quantization_channels = 256
......
...@@ -3,7 +3,6 @@ import torch
__all__ = [
'pad_trim',
'istft',
'spectrogram',
'create_fb_matrix',
...@@ -18,28 +17,6 @@ __all__ = [
]
@torch.jit.script
def pad_trim(waveform, max_len, fill_value):
    # type: (Tensor, int, float) -> Tensor
    r"""Force a 2D tensor to a fixed length along its last dimension.

    Args:
        waveform (torch.Tensor): Tensor of audio of size (c, n)
        max_len (int): Length to which the waveform will be padded
        fill_value (float): Value to fill in

    Returns:
        torch.Tensor: Padded/trimmed tensor of size (c, max_len)
    """
    current_length = waveform.size(1)
    deficit = max_len - current_length
    if deficit > 0:
        # Too short: append `deficit` samples of fill_value on the right.
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (0, deficit), 'constant', fill_value)
    else:
        # Long enough: keep only the first max_len samples.
        waveform = waveform[:, :max_len]
    return waveform
# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
@torch.jit.ignore
def _stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
......
...@@ -7,32 +7,6 @@ from . import functional as F
from .compliance import kaldi
class PadTrim(torch.jit.ScriptModule):
    r"""Pad/Trim a 2D tensor to a fixed length along its last dimension.

    Wraps ``F.pad_trim`` as a TorchScript module so it can be composed with
    other transforms and serialized.

    Args:
        max_len (int): Length to which the waveform will be padded
        fill_value (float): Value to fill in
    """
    # Expose these attributes as compile-time constants so the scripted
    # forward() can read them (required by torch.jit.ScriptModule).
    __constants__ = ['max_len', 'fill_value']

    def __init__(self, max_len, fill_value=0.):
        super(PadTrim, self).__init__()
        self.max_len = max_len
        self.fill_value = fill_value

    @torch.jit.script_method
    def forward(self, waveform):
        r"""Pad or trim the input to ``max_len`` samples.

        Args:
            waveform (torch.Tensor): Tensor of audio of size (c, n)

        Returns:
            Tensor: Tensor of size (c, `max_len`)
        """
        return F.pad_trim(waveform, self.max_len, self.fill_value)
class Spectrogram(torch.jit.ScriptModule):
r"""Create a spectrogram from a audio signal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment