Commit 6a43e9eb authored by jamarshon's avatar jamarshon Committed by cpuhrsch
Browse files

Renaming MuLawExpanding to MuLawDecoding (#159)

parent b29a4639
...@@ -170,25 +170,25 @@ class Test_JIT(unittest.TestCase): ...@@ -170,25 +170,25 @@ class Test_JIT(unittest.TestCase):
self._test_script_module(tensor, transforms.MuLawEncoding) self._test_script_module(tensor, transforms.MuLawEncoding)
def test_torchscript_mu_law_expanding(self): def test_torchscript_mu_law_decoding(self):
@torch.jit.script @torch.jit.script
def jit_method(tensor, qc): def jit_method(tensor, qc):
# type: (Tensor, int) -> Tensor # type: (Tensor, int) -> Tensor
return F.mu_law_expanding(tensor, qc) return F.mu_law_decoding(tensor, qc)
tensor = torch.rand((1, 10)) tensor = torch.rand((1, 10))
qc = 256 qc = 256
jit_out = jit_method(tensor, qc) jit_out = jit_method(tensor, qc)
py_out = F.mu_law_expanding(tensor, qc) py_out = F.mu_law_decoding(tensor, qc)
self.assertTrue(torch.allclose(jit_out, py_out)) self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_MuLawExpanding(self): def test_scriptmodule_MuLawDecoding(self):
tensor = torch.rand((1, 10), device="cuda") tensor = torch.rand((1, 10), device="cuda")
self._test_script_module(tensor, transforms.MuLawExpanding) self._test_script_module(tensor, transforms.MuLawDecoding)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -61,7 +61,7 @@ class Tester(unittest.TestCase): ...@@ -61,7 +61,7 @@ class Tester(unittest.TestCase):
waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform) waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels) self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)
waveform_exp = transforms.MuLawExpanding(quantization_channels)(waveform_mu) waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.) self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
def test_mel2(self): def test_mel2(self):
......
...@@ -10,7 +10,7 @@ __all__ = [ ...@@ -10,7 +10,7 @@ __all__ = [
'spectrogram_to_DB', 'spectrogram_to_DB',
'create_dct', 'create_dct',
'mu_law_encoding', 'mu_law_encoding',
'mu_law_expanding', 'mu_law_decoding',
'complex_norm', 'complex_norm',
'angle', 'angle',
'magphase', 'magphase',
...@@ -353,7 +353,7 @@ def mu_law_encoding(x, quantization_channels): ...@@ -353,7 +353,7 @@ def mu_law_encoding(x, quantization_channels):
@torch.jit.script @torch.jit.script
def mu_law_expanding(x_mu, quantization_channels): def mu_law_decoding(x_mu, quantization_channels):
# type: (Tensor, int) -> Tensor # type: (Tensor, int) -> Tensor
r"""Decode mu-law encoded signal. For more info see the r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......
...@@ -321,7 +321,7 @@ class MuLawEncoding(torch.jit.ScriptModule): ...@@ -321,7 +321,7 @@ class MuLawEncoding(torch.jit.ScriptModule):
return F.mu_law_encoding(x, self.quantization_channels) return F.mu_law_encoding(x, self.quantization_channels)
class MuLawExpanding(torch.jit.ScriptModule): class MuLawDecoding(torch.jit.ScriptModule):
r"""Decode mu-law encoded signal. For more info see the r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...@@ -334,7 +334,7 @@ class MuLawExpanding(torch.jit.ScriptModule): ...@@ -334,7 +334,7 @@ class MuLawExpanding(torch.jit.ScriptModule):
__constants__ = ['quantization_channels'] __constants__ = ['quantization_channels']
def __init__(self, quantization_channels=256): def __init__(self, quantization_channels=256):
super(MuLawExpanding, self).__init__() super(MuLawDecoding, self).__init__()
self.quantization_channels = quantization_channels self.quantization_channels = quantization_channels
@torch.jit.script_method @torch.jit.script_method
...@@ -346,7 +346,7 @@ class MuLawExpanding(torch.jit.ScriptModule): ...@@ -346,7 +346,7 @@ class MuLawExpanding(torch.jit.ScriptModule):
Returns: Returns:
torch.Tensor: The signal decoded torch.Tensor: The signal decoded
""" """
return F.mu_law_expanding(x_mu, self.quantization_channels) return F.mu_law_decoding(x_mu, self.quantization_channels)
class Resample(torch.nn.Module): class Resample(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment