Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
0c6ceb26
Unverified
Commit
0c6ceb26
authored
May 20, 2019
by
cpuhrsch
Committed by
GitHub
May 20, 2019
Browse files
Merge pull request #107 from jamarshon/numpyremove
Remove numpy support in transforms.py and functional.py
parents
00efbe61
51e27933
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
45 deletions
+28
-45
test/test_transforms.py
test/test_transforms.py
+0
-11
torchaudio/functional.py
torchaudio/functional.py
+22
-28
torchaudio/transforms.py
torchaudio/transforms.py
+6
-6
No files found.
test/test_transforms.py
View file @
0c6ceb26
...
@@ -107,18 +107,7 @@ class Tester(unittest.TestCase):
...
@@ -107,18 +107,7 @@ class Tester(unittest.TestCase):
def
test_mu_law_companding
(
self
):
def
test_mu_law_companding
(
self
):
sig
=
self
.
sig
.
clone
()
quantization_channels
=
256
quantization_channels
=
256
sig
=
self
.
sig
.
numpy
()
sig
=
sig
/
np
.
abs
(
sig
).
max
()
self
.
assertTrue
(
sig
.
min
()
>=
-
1.
and
sig
.
max
()
<=
1.
)
sig_mu
=
transforms
.
MuLawEncoding
(
quantization_channels
)(
sig
)
self
.
assertTrue
(
sig_mu
.
min
()
>=
0.
and
sig
.
max
()
<=
quantization_channels
)
sig_exp
=
transforms
.
MuLawExpanding
(
quantization_channels
)(
sig_mu
)
self
.
assertTrue
(
sig_exp
.
min
()
>=
-
1.
and
sig_exp
.
max
()
<=
1.
)
sig
=
self
.
sig
.
clone
()
sig
=
self
.
sig
.
clone
()
sig
=
sig
/
torch
.
abs
(
sig
).
max
()
sig
=
sig
/
torch
.
abs
(
sig
).
max
()
...
...
torchaudio/functional.py
View file @
0c6ceb26
import
numpy
as
np
import
math
import
torch
import
torch
...
@@ -245,15 +245,15 @@ def create_dct(n_mfcc, n_mels, norm):
...
@@ -245,15 +245,15 @@ def create_dct(n_mfcc, n_mels, norm):
outdim
=
n_mfcc
outdim
=
n_mfcc
dim
=
n_mels
dim
=
n_mels
# http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
# http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
n
=
np
.
arange
(
dim
)
n
=
torch
.
arange
(
dim
,
dtype
=
torch
.
get_default_dtype
()
)
k
=
np
.
arange
(
outdim
)[:,
np
.
newaxis
]
k
=
torch
.
arange
(
outdim
,
dtype
=
torch
.
get_default_dtype
())[:,
None
]
dct
=
np
.
cos
(
np
.
pi
/
dim
*
(
n
+
0.5
)
*
k
)
dct
=
torch
.
cos
(
math
.
pi
/
dim
*
(
n
+
0.5
)
*
k
)
if
norm
==
'ortho'
:
if
norm
==
'ortho'
:
dct
[
0
]
*=
1.0
/
np
.
sqrt
(
2
)
dct
[
0
]
*=
1.0
/
math
.
sqrt
(
2
.0
)
dct
*=
np
.
sqrt
(
2.0
/
dim
)
dct
*=
math
.
sqrt
(
2.0
/
dim
)
else
:
else
:
dct
*=
2
dct
*=
2
return
torch
.
Tensor
(
dct
.
T
)
return
dct
.
t
(
)
def
MFCC
(
sig
,
mel_spect
,
log_mels
,
s2db
,
dct_mat
):
def
MFCC
(
sig
,
mel_spect
,
log_mels
,
s2db
,
dct_mat
):
...
@@ -302,7 +302,7 @@ def BLC2CBL(tensor):
...
@@ -302,7 +302,7 @@ def BLC2CBL(tensor):
def
mu_law_encoding
(
x
,
qc
):
def
mu_law_encoding
(
x
,
qc
):
# type: (Tensor
/ndarray
, int) -> Tensor
/ndarray
# type: (Tensor, int) -> Tensor
"""Encode signal based on mu-law companding. For more info see the
"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...
@@ -316,22 +316,19 @@ def mu_law_encoding(x, qc):
...
@@ -316,22 +316,19 @@ def mu_law_encoding(x, qc):
Outputs:
Outputs:
Tensor: Input after mu-law companding
Tensor: Input after mu-law companding
"""
"""
assert
isinstance
(
x
,
torch
.
Tensor
),
'mu_law_encoding expects a Tensor'
mu
=
qc
-
1.
mu
=
qc
-
1.
if
isinstance
(
x
,
np
.
ndarray
):
if
not
x
.
dtype
.
is_floating_point
:
x_mu
=
np
.
sign
(
x
)
*
np
.
log1p
(
mu
*
np
.
abs
(
x
))
/
np
.
log1p
(
mu
)
x
=
x
.
to
(
torch
.
float
)
x_mu
=
((
x_mu
+
1
)
/
2
*
mu
+
0.5
).
astype
(
int
)
mu
=
torch
.
tensor
(
mu
,
dtype
=
x
.
dtype
)
elif
isinstance
(
x
,
torch
.
Tensor
):
x_mu
=
torch
.
sign
(
x
)
*
torch
.
log1p
(
mu
*
if
not
x
.
dtype
.
is_floating_point
:
torch
.
abs
(
x
))
/
torch
.
log1p
(
mu
)
x
=
x
.
to
(
torch
.
float
)
x_mu
=
((
x_mu
+
1
)
/
2
*
mu
+
0.5
).
to
(
torch
.
int64
)
mu
=
torch
.
tensor
(
mu
,
dtype
=
x
.
dtype
)
x_mu
=
torch
.
sign
(
x
)
*
torch
.
log1p
(
mu
*
torch
.
abs
(
x
))
/
torch
.
log1p
(
mu
)
x_mu
=
((
x_mu
+
1
)
/
2
*
mu
+
0.5
).
long
()
return
x_mu
return
x_mu
def
mu_law_expanding
(
x_mu
,
qc
):
def
mu_law_expanding
(
x_mu
,
qc
):
# type: (Tensor
/ndarray
, int) -> Tensor
/ndarray
# type: (Tensor, int) -> Tensor
"""Decode mu-law encoded signal. For more info see the
"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...
@@ -345,14 +342,11 @@ def mu_law_expanding(x_mu, qc):
...
@@ -345,14 +342,11 @@ def mu_law_expanding(x_mu, qc):
Outputs:
Outputs:
Tensor: Input after decoding
Tensor: Input after decoding
"""
"""
assert
isinstance
(
x_mu
,
torch
.
Tensor
),
'mu_law_expanding expects a Tensor'
mu
=
qc
-
1.
mu
=
qc
-
1.
if
isinstance
(
x_mu
,
np
.
ndarray
):
if
not
x_mu
.
dtype
.
is_floating_point
:
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x_mu
=
x_mu
.
to
(
torch
.
float
)
x
=
np
.
sign
(
x
)
*
(
np
.
exp
(
np
.
abs
(
x
)
*
np
.
log1p
(
mu
))
-
1.
)
/
mu
mu
=
torch
.
tensor
(
mu
,
dtype
=
x_mu
.
dtype
)
elif
isinstance
(
x_mu
,
torch
.
Tensor
):
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
if
not
x_mu
.
dtype
.
is_floating_point
:
x
=
torch
.
sign
(
x
)
*
(
torch
.
exp
(
torch
.
abs
(
x
)
*
torch
.
log1p
(
mu
))
-
1.
)
/
mu
x_mu
=
x_mu
.
to
(
torch
.
float
)
mu
=
torch
.
tensor
(
mu
,
dtype
=
x_mu
.
dtype
)
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
torch
.
sign
(
x
)
*
(
torch
.
exp
(
torch
.
abs
(
x
)
*
torch
.
log1p
(
mu
))
-
1.
)
/
mu
return
x
return
x
torchaudio/transforms.py
View file @
0c6ceb26
from
__future__
import
division
,
print_function
from
__future__
import
division
,
print_function
from
warnings
import
warn
from
warnings
import
warn
import
math
import
torch
import
torch
import
numpy
as
np
from
.
import
functional
as
F
from
.
import
functional
as
F
...
@@ -234,7 +234,7 @@ class SpectrogramToDB(object):
...
@@ -234,7 +234,7 @@ class SpectrogramToDB(object):
self
.
multiplier
=
10.
if
stype
==
"power"
else
20.
self
.
multiplier
=
10.
if
stype
==
"power"
else
20.
self
.
amin
=
1e-10
self
.
amin
=
1e-10
self
.
ref_value
=
1.
self
.
ref_value
=
1.
self
.
db_multiplier
=
np
.
log10
(
np
.
maximum
(
self
.
amin
,
self
.
ref_value
))
self
.
db_multiplier
=
math
.
log10
(
max
(
self
.
amin
,
self
.
ref_value
))
def
__call__
(
self
,
spec
):
def
__call__
(
self
,
spec
):
# numerically stable implementation from librosa
# numerically stable implementation from librosa
...
@@ -403,10 +403,10 @@ class MuLawEncoding(object):
...
@@ -403,10 +403,10 @@ class MuLawEncoding(object):
"""
"""
Args:
Args:
x (FloatTensor/LongTensor
or ndarray
)
x (FloatTensor/LongTensor)
Returns:
Returns:
x_mu (LongTensor
or ndarray
)
x_mu (LongTensor)
"""
"""
return
F
.
mu_law_encoding
(
x
,
self
.
qc
)
return
F
.
mu_law_encoding
(
x
,
self
.
qc
)
...
@@ -434,10 +434,10 @@ class MuLawExpanding(object):
...
@@ -434,10 +434,10 @@ class MuLawExpanding(object):
"""
"""
Args:
Args:
x_mu (
Float
Tensor
/LongTensor or ndarray
)
x_mu (Tensor)
Returns:
Returns:
x (
Float
Tensor
or ndarray
)
x (Tensor)
"""
"""
return
F
.
mu_law_expanding
(
x_mu
,
self
.
qc
)
return
F
.
mu_law_expanding
(
x_mu
,
self
.
qc
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment