"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "829f6defa44f93b84416aba36e5f43952c2df8bc"
Commit 9edce71c authored by Soumith Chintala's avatar Soumith Chintala Committed by GitHub
Browse files

Merge pull request #16 from dhpollack/mu_law_companding

mu-law companding transform
parents 9538c65f 38de2cb6
...@@ -93,6 +93,34 @@ class Tester(unittest.TestCase): ...@@ -93,6 +93,34 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(0) == length_new) self.assertTrue(result.size(0) == length_new)
def test_mu_law_companding(self):
    """Round-trip a signal through mu-law encoding/expanding, on both the
    numpy and the torch code paths.

    Checks that the encoded signal lies in [0, quantization_channels] and
    that the expanded signal returns to the [-1., 1.] range.

    NOTE: exact reconstruction (mse ~ 0) is NOT asserted — mu-law is lossy,
    so the round-trip error is not always below a fixed tolerance.
    """
    quantization_channels = 256

    # --- numpy path: normalize to [-1., 1.] before companding ---
    sig = self.sig.numpy()
    sig = sig / np.abs(sig).max()
    self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)

    sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
    # bug fix: assert the range of the *encoded* signal, not the input
    self.assertTrue(sig_mu.min() >= 0. and sig_mu.max() <= quantization_channels)

    sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
    self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)

    # --- torch path: same checks on a torch tensor ---
    sig = self.sig.clone()
    sig = sig / torch.abs(sig).max()
    self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)

    sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
    # bug fix: assert the range of the *encoded* signal, not the input
    self.assertTrue(sig_mu.min() >= 0. and sig_mu.max() <= quantization_channels)

    sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
    self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -33,7 +33,7 @@ class Scale(object): ...@@ -33,7 +33,7 @@ class Scale(object):
called the "bit depth" or "precision", not to be confused with "bit rate". called the "bit depth" or "precision", not to be confused with "bit rate".
Args: Args:
factor (float): maximum value of input tensor. default: 16-bit depth factor (int): maximum value of input tensor. default: 16-bit depth
""" """
...@@ -58,6 +58,10 @@ class Scale(object): ...@@ -58,6 +58,10 @@ class Scale(object):
class PadTrim(object): class PadTrim(object):
"""Pad/Trim a 1d-Tensor (Signal or Labels) """Pad/Trim a 1d-Tensor (Signal or Labels)
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
max_len (int): Length to which the tensor will be padded
""" """
def __init__(self, max_len, fill_value=0): def __init__(self, max_len, fill_value=0):
...@@ -67,10 +71,6 @@ class PadTrim(object): ...@@ -67,10 +71,6 @@ class PadTrim(object):
def __call__(self, tensor): def __call__(self, tensor):
""" """
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
max_len (int): Length to which the tensor will be padded
Returns: Returns:
Tensor: (max_len x Channels) Tensor: (max_len x Channels)
...@@ -88,21 +88,18 @@ class PadTrim(object): ...@@ -88,21 +88,18 @@ class PadTrim(object):
class DownmixMono(object): class DownmixMono(object):
"""Downmix any stereo signals to mono """Downmix any stereo signals to mono
Inputs:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
Returns:
tensor (Tensor) (Samples x 1):
""" """
def __init__(self): def __init__(self):
pass pass
def __call__(self, tensor): def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
Returns:
Tensor: (Samples x 1)
"""
if isinstance(tensor, (torch.LongTensor, torch.IntTensor)): if isinstance(tensor, (torch.LongTensor, torch.IntTensor)):
tensor = tensor.float() tensor = tensor.float()
...@@ -181,3 +178,77 @@ class BLC2CBL(object): ...@@ -181,3 +178,77 @@ class BLC2CBL(object):
""" """
return tensor.permute(2, 0, 1).contiguous() return tensor.permute(2, 0, 1).contiguous()
class MuLawEncoding(object):
    """Encode a signal with mu-law companding.

    For more info see the `Wikipedia entry
    <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_.

    This algorithm assumes the signal has been scaled to between -1 and 1
    and returns a signal encoded with values from 0 to
    quantization_channels - 1.

    Args:
        quantization_channels (int): Number of channels. default: 256
    """

    def __init__(self, quantization_channels=256):
        self.qc = quantization_channels

    def __call__(self, x):
        """
        Args:
            x (FloatTensor/LongTensor or ndarray): signal scaled to [-1., 1.]

        Returns:
            x_mu (LongTensor or int ndarray): mu-law encoded signal

        Raises:
            TypeError: if x is neither an ndarray nor a torch Tensor.
        """
        mu = self.qc - 1.
        if isinstance(x, np.ndarray):
            x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
            # map [-1., 1.] -> [0, mu] and round to the nearest integer
            x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
        elif isinstance(x, (torch.Tensor, torch.LongTensor)):
            if isinstance(x, torch.LongTensor):
                x = x.float()
            mu = torch.FloatTensor([mu])
            x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
            x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
        else:
            # bug fix: an unsupported type previously fell through both
            # branches and raised an opaque UnboundLocalError on return
            raise TypeError("expected an ndarray or torch Tensor, "
                            "got {}".format(type(x).__name__))
        return x_mu
class MuLawExpanding(object):
    """Decode a mu-law encoded signal.

    For more info see the `Wikipedia entry
    <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_.

    This expects an input with values between 0 and
    quantization_channels - 1 and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int): Number of channels. default: 256
    """

    def __init__(self, quantization_channels=256):
        self.qc = quantization_channels

    def __call__(self, x_mu):
        """
        Args:
            x_mu (FloatTensor/LongTensor or ndarray): mu-law encoded signal
                with values in [0, quantization_channels - 1]

        Returns:
            x (FloatTensor or ndarray): decoded signal in [-1., 1.]

        Raises:
            TypeError: if x_mu is neither an ndarray nor a torch Tensor.
        """
        mu = self.qc - 1.
        if isinstance(x_mu, np.ndarray):
            # map [0, mu] -> [-1., 1.], then invert the log companding
            x = ((x_mu) / mu) * 2 - 1.
            x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
        elif isinstance(x_mu, (torch.Tensor, torch.LongTensor)):
            if isinstance(x_mu, torch.LongTensor):
                x_mu = x_mu.float()
            mu = torch.FloatTensor([mu])
            x = ((x_mu) / mu) * 2 - 1.
            x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
        else:
            # bug fix: an unsupported type previously fell through both
            # branches and raised an opaque UnboundLocalError on return
            raise TypeError("expected an ndarray or torch Tensor, "
                            "got {}".format(type(x_mu).__name__))
        return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment