Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
ac0c94be
Commit
ac0c94be
authored
Jan 23, 2018
by
vishwakftw
Browse files
Add __repr__ for transforms and tests
parent
67564173
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
57 additions
and
0 deletions
+57
-0
test/test_transforms.py
test/test_transforms.py
+25
-0
torchaudio/transforms.py
torchaudio/transforms.py
+32
-0
No files found.
test/test_transforms.py
View file @
ac0c94be
...
@@ -29,6 +29,9 @@ class Tester(unittest.TestCase):
...
@@ -29,6 +29,9 @@ class Tester(unittest.TestCase):
result
.
min
()
>=
-
1.
and
result
.
max
()
<=
1.
,
result
.
min
()
>=
-
1.
and
result
.
max
()
<=
1.
,
print
(
"min: {}, max: {}"
.
format
(
result
.
min
(),
result
.
max
())))
print
(
"min: {}, max: {}"
.
format
(
result
.
min
(),
result
.
max
())))
repr_test
=
transforms
.
Scale
()
repr_test
.
__repr__
()
def
test_pad_trim
(
self
):
def
test_pad_trim
(
self
):
audio_orig
=
self
.
sig
.
clone
()
audio_orig
=
self
.
sig
.
clone
()
...
@@ -49,6 +52,9 @@ class Tester(unittest.TestCase):
...
@@ -49,6 +52,9 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
result
.
size
(
0
)
==
length_new
,
self
.
assertTrue
(
result
.
size
(
0
)
==
length_new
,
print
(
"old size: {}, new size: {}"
.
format
(
audio_orig
.
size
(
0
),
result
.
size
(
0
))))
print
(
"old size: {}, new size: {}"
.
format
(
audio_orig
.
size
(
0
),
result
.
size
(
0
))))
repr_test
=
transforms
.
PadTrim
(
max_len
=
length_new
)
repr_test
.
__repr__
()
def
test_downmix_mono
(
self
):
def
test_downmix_mono
(
self
):
audio_L
=
self
.
sig
.
clone
()
audio_L
=
self
.
sig
.
clone
()
...
@@ -64,12 +70,18 @@ class Tester(unittest.TestCase):
...
@@ -64,12 +70,18 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
result
.
size
(
1
)
==
1
)
self
.
assertTrue
(
result
.
size
(
1
)
==
1
)
repr_test
=
transforms
.
DownmixMono
()
repr_test
.
__repr__
()
def
test_lc2cl
(
self
):
def
test_lc2cl
(
self
):
audio
=
self
.
sig
.
clone
()
audio
=
self
.
sig
.
clone
()
result
=
transforms
.
LC2CL
()(
audio
)
result
=
transforms
.
LC2CL
()(
audio
)
self
.
assertTrue
(
result
.
size
()[::
-
1
]
==
audio
.
size
())
self
.
assertTrue
(
result
.
size
()[::
-
1
]
==
audio
.
size
())
repr_test
=
transforms
.
LC2CL
()
repr_test
.
__repr__
()
def
test_mel
(
self
):
def
test_mel
(
self
):
audio
=
self
.
sig
.
clone
()
audio
=
self
.
sig
.
clone
()
...
@@ -80,6 +92,11 @@ class Tester(unittest.TestCase):
...
@@ -80,6 +92,11 @@ class Tester(unittest.TestCase):
result
=
transforms
.
BLC2CBL
()(
result
)
result
=
transforms
.
BLC2CBL
()(
result
)
self
.
assertTrue
(
len
(
result
.
size
())
==
3
)
self
.
assertTrue
(
len
(
result
.
size
())
==
3
)
repr_test
=
transforms
.
MEL
()
repr_test
.
__repr__
()
repr_test
=
transforms
.
BLC2CBL
()
repr_test
.
__repr__
()
def
test_compose
(
self
):
def
test_compose
(
self
):
audio_orig
=
self
.
sig
.
clone
()
audio_orig
=
self
.
sig
.
clone
()
...
@@ -96,6 +113,9 @@ class Tester(unittest.TestCase):
...
@@ -96,6 +113,9 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
result
.
size
(
0
)
==
length_new
)
self
.
assertTrue
(
result
.
size
(
0
)
==
length_new
)
repr_test
=
transforms
.
Compose
(
tset
)
repr_test
.
__repr__
()
def
test_mu_law_companding
(
self
):
def
test_mu_law_companding
(
self
):
sig
=
self
.
sig
.
clone
()
sig
=
self
.
sig
.
clone
()
...
@@ -121,6 +141,11 @@ class Tester(unittest.TestCase):
...
@@ -121,6 +141,11 @@ class Tester(unittest.TestCase):
sig_exp
=
transforms
.
MuLawExpanding
(
quantization_channels
)(
sig_mu
)
sig_exp
=
transforms
.
MuLawExpanding
(
quantization_channels
)(
sig_mu
)
self
.
assertTrue
(
sig_exp
.
min
()
>=
-
1.
and
sig_exp
.
max
()
<=
1.
)
self
.
assertTrue
(
sig_exp
.
min
()
>=
-
1.
and
sig_exp
.
max
()
<=
1.
)
repr_test
=
transforms
.
MuLawEncoding
(
quantization_channels
)
repr_test
.
__repr__
()
repr_test
=
transforms
.
MuLawExpanding
(
quantization_channels
)
repr_test
.
__repr__
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
torchaudio/transforms.py
View file @
ac0c94be
...
@@ -28,6 +28,14 @@ class Compose(object):
...
@@ -28,6 +28,14 @@ class Compose(object):
audio
=
t
(
audio
)
audio
=
t
(
audio
)
return
audio
return
audio
def
__repr__
(
self
):
format_string
=
self
.
__class__
.
__name__
+
'('
for
t
in
self
.
transforms
:
format_string
+=
'
\n
'
format_string
+=
' {0}'
.
format
(
t
)
format_string
+=
'
\n
)'
return
format_string
class
Scale
(
object
):
class
Scale
(
object
):
"""Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
"""Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
...
@@ -57,6 +65,9 @@ class Scale(object):
...
@@ -57,6 +65,9 @@ class Scale(object):
return
tensor
/
self
.
factor
return
tensor
/
self
.
factor
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
'()'
class
PadTrim
(
object
):
class
PadTrim
(
object
):
"""Pad/Trim a 1d-Tensor (Signal or Labels)
"""Pad/Trim a 1d-Tensor (Signal or Labels)
...
@@ -87,6 +98,9 @@ class PadTrim(object):
...
@@ -87,6 +98,9 @@ class PadTrim(object):
tensor
=
tensor
[:
self
.
max_len
,
:]
tensor
=
tensor
[:
self
.
max_len
,
:]
return
tensor
return
tensor
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
'(max_len={0})'
.
format
(
self
.
max_len
)
class
DownmixMono
(
object
):
class
DownmixMono
(
object
):
"""Downmix any stereo signals to mono
"""Downmix any stereo signals to mono
...
@@ -110,6 +124,9 @@ class DownmixMono(object):
...
@@ -110,6 +124,9 @@ class DownmixMono(object):
tensor
=
torch
.
mean
(
tensor
.
float
(),
1
,
True
)
tensor
=
torch
.
mean
(
tensor
.
float
(),
1
,
True
)
return
tensor
return
tensor
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
'()'
class
LC2CL
(
object
):
class
LC2CL
(
object
):
"""Permute a 2d tensor from samples (Length) x Channels to Channels x
"""Permute a 2d tensor from samples (Length) x Channels to Channels x
...
@@ -129,6 +146,9 @@ class LC2CL(object):
...
@@ -129,6 +146,9 @@ class LC2CL(object):
return
tensor
.
transpose
(
0
,
1
).
contiguous
()
return
tensor
.
transpose
(
0
,
1
).
contiguous
()
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
'()'
class
MEL
(
object
):
class
MEL
(
object
):
"""Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.
"""Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.
...
@@ -166,6 +186,9 @@ class MEL(object):
...
@@ -166,6 +186,9 @@ class MEL(object):
return
tensor
return
tensor
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
'()'
class
BLC2CBL
(
object
):
class
BLC2CBL
(
object
):
"""Permute a 3d tensor from Bands x samples (Length) x Channels to Channels x
"""Permute a 3d tensor from Bands x samples (Length) x Channels to Channels x
...
@@ -185,6 +208,9 @@ class BLC2CBL(object):
...
@@ -185,6 +208,9 @@ class BLC2CBL(object):
return
tensor
.
permute
(
2
,
0
,
1
).
contiguous
()
return
tensor
.
permute
(
2
,
0
,
1
).
contiguous
()
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
'()'
class
MuLawEncoding
(
object
):
class
MuLawEncoding
(
object
):
"""Encode signal based on mu-law companding. For more info see the
"""Encode signal based on mu-law companding. For more info see the
...
@@ -224,6 +250,9 @@ class MuLawEncoding(object):
...
@@ -224,6 +250,9 @@ class MuLawEncoding(object):
x_mu
=
((
x_mu
+
1
)
/
2
*
mu
+
0.5
).
long
()
x_mu
=
((
x_mu
+
1
)
/
2
*
mu
+
0.5
).
long
()
return
x_mu
return
x_mu
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
'()'
class
MuLawExpanding
(
object
):
class
MuLawExpanding
(
object
):
"""Decode mu-law encoded signal. For more info see the
"""Decode mu-law encoded signal. For more info see the
...
@@ -261,3 +290,6 @@ class MuLawExpanding(object):
...
@@ -261,3 +290,6 @@ class MuLawExpanding(object):
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
torch
.
sign
(
x
)
*
(
torch
.
exp
(
torch
.
abs
(
x
)
*
torch
.
log1p
(
mu
))
-
1.
)
/
mu
x
=
torch
.
sign
(
x
)
*
(
torch
.
exp
(
torch
.
abs
(
x
)
*
torch
.
log1p
(
mu
))
-
1.
)
/
mu
return
x
return
x
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
'()'
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment