Commit 6a43e9eb authored by jamarshon's avatar jamarshon Committed by cpuhrsch
Browse files

Renaming MuLawExpanding to MuLawDecoding (#159)

parent b29a4639
......@@ -170,25 +170,25 @@ class Test_JIT(unittest.TestCase):
self._test_script_module(tensor, transforms.MuLawEncoding)
def test_torchscript_mu_law_expanding(self):
def test_torchscript_mu_law_decoding(self):
@torch.jit.script
def jit_method(tensor, qc):
# type: (Tensor, int) -> Tensor
return F.mu_law_expanding(tensor, qc)
return F.mu_law_decoding(tensor, qc)
tensor = torch.rand((1, 10))
qc = 256
jit_out = jit_method(tensor, qc)
py_out = F.mu_law_expanding(tensor, qc)
py_out = F.mu_law_decoding(tensor, qc)
self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_MuLawExpanding(self):
def test_scriptmodule_MuLawDecoding(self):
tensor = torch.rand((1, 10), device="cuda")
self._test_script_module(tensor, transforms.MuLawExpanding)
self._test_script_module(tensor, transforms.MuLawDecoding)
if __name__ == '__main__':
......
......@@ -61,7 +61,7 @@ class Tester(unittest.TestCase):
waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)
waveform_exp = transforms.MuLawExpanding(quantization_channels)(waveform_mu)
waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
def test_mel2(self):
......
......@@ -10,7 +10,7 @@ __all__ = [
'spectrogram_to_DB',
'create_dct',
'mu_law_encoding',
'mu_law_expanding',
'mu_law_decoding',
'complex_norm',
'angle',
'magphase',
......@@ -353,7 +353,7 @@ def mu_law_encoding(x, quantization_channels):
@torch.jit.script
def mu_law_expanding(x_mu, quantization_channels):
def mu_law_decoding(x_mu, quantization_channels):
# type: (Tensor, int) -> Tensor
r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......
......@@ -321,7 +321,7 @@ class MuLawEncoding(torch.jit.ScriptModule):
return F.mu_law_encoding(x, self.quantization_channels)
class MuLawExpanding(torch.jit.ScriptModule):
class MuLawDecoding(torch.jit.ScriptModule):
r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......@@ -334,7 +334,7 @@ class MuLawExpanding(torch.jit.ScriptModule):
__constants__ = ['quantization_channels']
def __init__(self, quantization_channels=256):
super(MuLawExpanding, self).__init__()
super(MuLawDecoding, self).__init__()
self.quantization_channels = quantization_channels
@torch.jit.script_method
......@@ -346,7 +346,7 @@ class MuLawExpanding(torch.jit.ScriptModule):
Returns:
torch.Tensor: The signal decoded
"""
return F.mu_law_expanding(x_mu, self.quantization_channels)
return F.mu_law_decoding(x_mu, self.quantization_channels)
class Resample(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment