Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
9edce71c
Commit
9edce71c
authored
Sep 18, 2017
by
Soumith Chintala
Committed by
GitHub
Sep 18, 2017
Browse files
Merge pull request #16 from dhpollack/mu_law_companding
mu-law companding transform
parents
9538c65f
38de2cb6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
113 additions
and
14 deletions
+113
-14
test/test_transforms.py
test/test_transforms.py
+28
-0
torchaudio/transforms.py
torchaudio/transforms.py
+85
-14
No files found.
test/test_transforms.py
View file @
9edce71c
...
...
@@ -93,6 +93,34 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
result
.
size
(
0
)
==
length_new
)
def
test_mu_law_companding
(
self
):
sig
=
self
.
sig
.
clone
()
quantization_channels
=
256
sig
=
self
.
sig
.
numpy
()
sig
=
sig
/
np
.
abs
(
sig
).
max
()
self
.
assertTrue
(
sig
.
min
()
>=
-
1.
and
sig
.
max
()
<=
1.
)
sig_mu
=
transforms
.
MuLawEncoding
(
quantization_channels
)(
sig
)
self
.
assertTrue
(
sig_mu
.
min
()
>=
0.
and
sig
.
max
()
<=
quantization_channels
)
sig_exp
=
transforms
.
MuLawExpanding
(
quantization_channels
)(
sig_mu
)
self
.
assertTrue
(
sig_exp
.
min
()
>=
-
1.
and
sig_exp
.
max
()
<=
1.
)
#diff = sig - sig_exp
#mse = np.linalg.norm(diff) / diff.shape[0]
#self.assertTrue(mse, np.isclose(mse, 0., atol=1e-4)) # not always true
sig
=
self
.
sig
.
clone
()
sig
=
sig
/
torch
.
abs
(
sig
).
max
()
self
.
assertTrue
(
sig
.
min
()
>=
-
1.
and
sig
.
max
()
<=
1.
)
sig_mu
=
transforms
.
MuLawEncoding
(
quantization_channels
)(
sig
)
self
.
assertTrue
(
sig_mu
.
min
()
>=
0.
and
sig
.
max
()
<=
quantization_channels
)
sig_exp
=
transforms
.
MuLawExpanding
(
quantization_channels
)(
sig_mu
)
self
.
assertTrue
(
sig_exp
.
min
()
>=
-
1.
and
sig_exp
.
max
()
<=
1.
)
if
__name__
==
'__main__'
:
unittest
.
main
()
torchaudio/transforms.py
View file @
9edce71c
...
...
@@ -33,7 +33,7 @@ class Scale(object):
called the "bit depth" or "precision", not to be confused with "bit rate".
Args:
factor (
floa
t): maximum value of input tensor. default: 16-bit depth
factor (
in
t): maximum value of input tensor. default: 16-bit depth
"""
...
...
@@ -58,6 +58,10 @@ class Scale(object):
class
PadTrim
(
object
):
"""Pad/Trim a 1d-Tensor (Signal or Labels)
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
max_len (int): Length to which the tensor will be padded
"""
def
__init__
(
self
,
max_len
,
fill_value
=
0
):
...
...
@@ -67,10 +71,6 @@ class PadTrim(object):
def
__call__
(
self
,
tensor
):
"""
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
max_len (int): Length to which the tensor will be padded
Returns:
Tensor: (max_len x Channels)
...
...
@@ -88,21 +88,18 @@ class PadTrim(object):
class
DownmixMono
(
object
):
"""Downmix any stereo signals to mono
Inputs:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
Returns:
tensor (Tensor) (Samples x 1):
"""
def
__init__
(
self
):
pass
def
__call__
(
self
,
tensor
):
"""
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
Returns:
Tensor: (Samples x 1)
"""
if
isinstance
(
tensor
,
(
torch
.
LongTensor
,
torch
.
IntTensor
)):
tensor
=
tensor
.
float
()
...
...
@@ -181,3 +178,77 @@ class BLC2CBL(object):
"""
return
tensor
.
permute
(
2
,
0
,
1
).
contiguous
()
class
MuLawEncoding
(
object
):
"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
This algorithm assumes the signal has been scaled to between -1 and 1 and
returns a signal encoded with values from 0 to quantization_channels - 1
Args:
quantization_channels (int): Number of channels. default: 256
"""
def
__init__
(
self
,
quantization_channels
=
256
):
self
.
qc
=
quantization_channels
def
__call__
(
self
,
x
):
"""
Args:
x (FloatTensor/LongTensor or ndarray)
Returns:
x_mu (LongTensor or ndarray)
"""
mu
=
self
.
qc
-
1.
if
isinstance
(
x
,
np
.
ndarray
):
x_mu
=
np
.
sign
(
x
)
*
np
.
log1p
(
mu
*
np
.
abs
(
x
))
/
np
.
log1p
(
mu
)
x_mu
=
((
x_mu
+
1
)
/
2
*
mu
+
0.5
).
astype
(
int
)
elif
isinstance
(
x
,
(
torch
.
Tensor
,
torch
.
LongTensor
)):
if
isinstance
(
x
,
torch
.
LongTensor
):
x
=
x
.
float
()
mu
=
torch
.
FloatTensor
([
mu
])
x_mu
=
torch
.
sign
(
x
)
*
torch
.
log1p
(
mu
*
torch
.
abs
(
x
))
/
torch
.
log1p
(
mu
)
x_mu
=
((
x_mu
+
1
)
/
2
*
mu
+
0.5
).
long
()
return
x_mu
class
MuLawExpanding
(
object
):
"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
This expects an input with values between 0 and quantization_channels - 1
and returns a signal scaled between -1 and 1.
Args:
quantization_channels (int): Number of channels. default: 256
"""
def
__init__
(
self
,
quantization_channels
=
256
):
self
.
qc
=
quantization_channels
def
__call__
(
self
,
x_mu
):
"""
Args:
x_mu (FloatTensor/LongTensor or ndarray)
Returns:
x (FloatTensor or ndarray)
"""
mu
=
self
.
qc
-
1.
if
isinstance
(
x_mu
,
np
.
ndarray
):
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
np
.
sign
(
x
)
*
(
np
.
exp
(
np
.
abs
(
x
)
*
np
.
log1p
(
mu
))
-
1.
)
/
mu
elif
isinstance
(
x_mu
,
(
torch
.
Tensor
,
torch
.
LongTensor
)):
if
isinstance
(
x_mu
,
torch
.
LongTensor
):
x_mu
=
x_mu
.
float
()
mu
=
torch
.
FloatTensor
([
mu
])
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
torch
.
sign
(
x
)
*
(
torch
.
exp
(
torch
.
abs
(
x
)
*
torch
.
log1p
(
mu
))
-
1.
)
/
mu
return
x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment