Commit ac0c94be authored by vishwakftw's avatar vishwakftw
Browse files

Add __repr__ for transforms and tests

parent 67564173
...@@ -29,6 +29,9 @@ class Tester(unittest.TestCase): ...@@ -29,6 +29,9 @@ class Tester(unittest.TestCase):
result.min() >= -1. and result.max() <= 1., result.min() >= -1. and result.max() <= 1.,
print("min: {}, max: {}".format(result.min(), result.max()))) print("min: {}, max: {}".format(result.min(), result.max())))
repr_test = transforms.Scale()
repr_test.__repr__()
def test_pad_trim(self): def test_pad_trim(self):
audio_orig = self.sig.clone() audio_orig = self.sig.clone()
...@@ -49,6 +52,9 @@ class Tester(unittest.TestCase): ...@@ -49,6 +52,9 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(0) == length_new, self.assertTrue(result.size(0) == length_new,
print("old size: {}, new size: {}".format(audio_orig.size(0), result.size(0)))) print("old size: {}, new size: {}".format(audio_orig.size(0), result.size(0))))
repr_test = transforms.PadTrim(max_len=length_new)
repr_test.__repr__()
def test_downmix_mono(self): def test_downmix_mono(self):
audio_L = self.sig.clone() audio_L = self.sig.clone()
...@@ -64,12 +70,18 @@ class Tester(unittest.TestCase): ...@@ -64,12 +70,18 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(1) == 1) self.assertTrue(result.size(1) == 1)
repr_test = transforms.DownmixMono()
repr_test.__repr__()
def test_lc2cl(self): def test_lc2cl(self):
audio = self.sig.clone() audio = self.sig.clone()
result = transforms.LC2CL()(audio) result = transforms.LC2CL()(audio)
self.assertTrue(result.size()[::-1] == audio.size()) self.assertTrue(result.size()[::-1] == audio.size())
repr_test = transforms.LC2CL()
repr_test.__repr__()
def test_mel(self): def test_mel(self):
audio = self.sig.clone() audio = self.sig.clone()
...@@ -80,6 +92,11 @@ class Tester(unittest.TestCase): ...@@ -80,6 +92,11 @@ class Tester(unittest.TestCase):
result = transforms.BLC2CBL()(result) result = transforms.BLC2CBL()(result)
self.assertTrue(len(result.size()) == 3) self.assertTrue(len(result.size()) == 3)
repr_test = transforms.MEL()
repr_test.__repr__()
repr_test = transforms.BLC2CBL()
repr_test.__repr__()
def test_compose(self): def test_compose(self):
audio_orig = self.sig.clone() audio_orig = self.sig.clone()
...@@ -96,6 +113,9 @@ class Tester(unittest.TestCase): ...@@ -96,6 +113,9 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(0) == length_new) self.assertTrue(result.size(0) == length_new)
repr_test = transforms.Compose(tset)
repr_test.__repr__()
def test_mu_law_companding(self): def test_mu_law_companding(self):
sig = self.sig.clone() sig = self.sig.clone()
...@@ -121,6 +141,11 @@ class Tester(unittest.TestCase): ...@@ -121,6 +141,11 @@ class Tester(unittest.TestCase):
sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu) sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.) self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
repr_test = transforms.MuLawEncoding(quantization_channels)
repr_test.__repr__()
repr_test = transforms.MuLawExpanding(quantization_channels)
repr_test.__repr__()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -28,6 +28,14 @@ class Compose(object): ...@@ -28,6 +28,14 @@ class Compose(object):
audio = t(audio) audio = t(audio)
return audio return audio
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class Scale(object): class Scale(object):
"""Scale audio tensor from a 16-bit integer (represented as a FloatTensor) """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
...@@ -57,6 +65,9 @@ class Scale(object): ...@@ -57,6 +65,9 @@ class Scale(object):
return tensor / self.factor return tensor / self.factor
def __repr__(self):
return self.__class__.__name__ + '()'
class PadTrim(object): class PadTrim(object):
"""Pad/Trim a 1d-Tensor (Signal or Labels) """Pad/Trim a 1d-Tensor (Signal or Labels)
...@@ -87,6 +98,9 @@ class PadTrim(object): ...@@ -87,6 +98,9 @@ class PadTrim(object):
tensor = tensor[:self.max_len, :] tensor = tensor[:self.max_len, :]
return tensor return tensor
def __repr__(self):
return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)
class DownmixMono(object): class DownmixMono(object):
"""Downmix any stereo signals to mono """Downmix any stereo signals to mono
...@@ -110,6 +124,9 @@ class DownmixMono(object): ...@@ -110,6 +124,9 @@ class DownmixMono(object):
tensor = torch.mean(tensor.float(), 1, True) tensor = torch.mean(tensor.float(), 1, True)
return tensor return tensor
def __repr__(self):
return self.__class__.__name__ + '()'
class LC2CL(object): class LC2CL(object):
"""Permute a 2d tensor from samples (Length) x Channels to Channels x """Permute a 2d tensor from samples (Length) x Channels to Channels x
...@@ -129,6 +146,9 @@ class LC2CL(object): ...@@ -129,6 +146,9 @@ class LC2CL(object):
return tensor.transpose(0, 1).contiguous() return tensor.transpose(0, 1).contiguous()
def __repr__(self):
return self.__class__.__name__ + '()'
class MEL(object): class MEL(object):
"""Create MEL Spectrograms from a raw audio signal. Relatively pretty slow. """Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.
...@@ -166,6 +186,9 @@ class MEL(object): ...@@ -166,6 +186,9 @@ class MEL(object):
return tensor return tensor
def __repr__(self):
return self.__class__.__name__ + '()'
class BLC2CBL(object): class BLC2CBL(object):
"""Permute a 3d tensor from Bands x samples (Length) x Channels to Channels x """Permute a 3d tensor from Bands x samples (Length) x Channels to Channels x
...@@ -185,6 +208,9 @@ class BLC2CBL(object): ...@@ -185,6 +208,9 @@ class BLC2CBL(object):
return tensor.permute(2, 0, 1).contiguous() return tensor.permute(2, 0, 1).contiguous()
def __repr__(self):
return self.__class__.__name__ + '()'
class MuLawEncoding(object): class MuLawEncoding(object):
"""Encode signal based on mu-law companding. For more info see the """Encode signal based on mu-law companding. For more info see the
...@@ -224,6 +250,9 @@ class MuLawEncoding(object): ...@@ -224,6 +250,9 @@ class MuLawEncoding(object):
x_mu = ((x_mu + 1) / 2 * mu + 0.5).long() x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
return x_mu return x_mu
def __repr__(self):
return self.__class__.__name__ + '()'
class MuLawExpanding(object): class MuLawExpanding(object):
"""Decode mu-law encoded signal. For more info see the """Decode mu-law encoded signal. For more info see the
...@@ -261,3 +290,6 @@ class MuLawExpanding(object): ...@@ -261,3 +290,6 @@ class MuLawExpanding(object):
x = ((x_mu) / mu) * 2 - 1. x = ((x_mu) / mu) * 2 - 1.
x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
return x return x
def __repr__(self):
return self.__class__.__name__ + '()'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment