"docs/cli.mdx" did not exist on "5483497d7a2156c609e7903f680302b11fee40b9"
Unverified Commit 4a3e5006 authored by Soumith Chintala's avatar Soumith Chintala Committed by GitHub
Browse files

Merge pull request #28 from vishwakftw/transforms-docs

Add __repr__ for transforms
parents 67564173 ac0c94be
......@@ -29,6 +29,9 @@ class Tester(unittest.TestCase):
result.min() >= -1. and result.max() <= 1.,
print("min: {}, max: {}".format(result.min(), result.max())))
repr_test = transforms.Scale()
repr_test.__repr__()
def test_pad_trim(self):
audio_orig = self.sig.clone()
......@@ -49,6 +52,9 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(0) == length_new,
print("old size: {}, new size: {}".format(audio_orig.size(0), result.size(0))))
repr_test = transforms.PadTrim(max_len=length_new)
repr_test.__repr__()
def test_downmix_mono(self):
audio_L = self.sig.clone()
......@@ -64,12 +70,18 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(1) == 1)
repr_test = transforms.DownmixMono()
repr_test.__repr__()
def test_lc2cl(self):
audio = self.sig.clone()
result = transforms.LC2CL()(audio)
self.assertTrue(result.size()[::-1] == audio.size())
repr_test = transforms.LC2CL()
repr_test.__repr__()
def test_mel(self):
audio = self.sig.clone()
......@@ -80,6 +92,11 @@ class Tester(unittest.TestCase):
result = transforms.BLC2CBL()(result)
self.assertTrue(len(result.size()) == 3)
repr_test = transforms.MEL()
repr_test.__repr__()
repr_test = transforms.BLC2CBL()
repr_test.__repr__()
def test_compose(self):
audio_orig = self.sig.clone()
......@@ -96,6 +113,9 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(0) == length_new)
repr_test = transforms.Compose(tset)
repr_test.__repr__()
def test_mu_law_companding(self):
sig = self.sig.clone()
......@@ -121,6 +141,11 @@ class Tester(unittest.TestCase):
sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
repr_test = transforms.MuLawEncoding(quantization_channels)
repr_test.__repr__()
repr_test = transforms.MuLawExpanding(quantization_channels)
repr_test.__repr__()
if __name__ == '__main__':
unittest.main()
......@@ -28,6 +28,14 @@ class Compose(object):
audio = t(audio)
return audio
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class Scale(object):
"""Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
......@@ -57,6 +65,9 @@ class Scale(object):
return tensor / self.factor
def __repr__(self):
return self.__class__.__name__ + '()'
class PadTrim(object):
"""Pad/Trim a 1d-Tensor (Signal or Labels)
......@@ -87,6 +98,9 @@ class PadTrim(object):
tensor = tensor[:self.max_len, :]
return tensor
def __repr__(self):
return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)
class DownmixMono(object):
"""Downmix any stereo signals to mono
......@@ -110,6 +124,9 @@ class DownmixMono(object):
tensor = torch.mean(tensor.float(), 1, True)
return tensor
def __repr__(self):
return self.__class__.__name__ + '()'
class LC2CL(object):
"""Permute a 2d tensor from samples (Length) x Channels to Channels x
......@@ -129,6 +146,9 @@ class LC2CL(object):
return tensor.transpose(0, 1).contiguous()
def __repr__(self):
return self.__class__.__name__ + '()'
class MEL(object):
"""Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.
......@@ -166,6 +186,9 @@ class MEL(object):
return tensor
def __repr__(self):
return self.__class__.__name__ + '()'
class BLC2CBL(object):
"""Permute a 3d tensor from Bands x samples (Length) x Channels to Channels x
......@@ -185,6 +208,9 @@ class BLC2CBL(object):
return tensor.permute(2, 0, 1).contiguous()
def __repr__(self):
return self.__class__.__name__ + '()'
class MuLawEncoding(object):
"""Encode signal based on mu-law companding. For more info see the
......@@ -224,6 +250,9 @@ class MuLawEncoding(object):
x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
return x_mu
def __repr__(self):
return self.__class__.__name__ + '()'
class MuLawExpanding(object):
"""Decode mu-law encoded signal. For more info see the
......@@ -261,3 +290,6 @@ class MuLawExpanding(object):
x = ((x_mu) / mu) * 2 - 1.
x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
return x
def __repr__(self):
return self.__class__.__name__ + '()'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment