Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
e5c4de87
Unverified
Commit
e5c4de87
authored
Feb 16, 2021
by
Prabhat Roy
Committed by
GitHub
Feb 15, 2021
Browse files
Replace dtype if-elseif-else with switch (#1270)
parent
d58ac213
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
65 deletions
+87
-65
torchaudio/csrc/sox/effects_chain.cpp
torchaudio/csrc/sox/effects_chain.cpp
+27
-19
torchaudio/csrc/sox/utils.cpp
torchaudio/csrc/sox/utils.cpp
+60
-46
No files found.
torchaudio/csrc/sox/effects_chain.cpp
View file @
e5c4de87
...
...
@@ -80,28 +80,36 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// Convert to sox_sample_t (int32_t) and write to buffer
SOX_SAMPLE_LOCALS
;
const
auto
dtype
=
tensor_
.
dtype
();
if
(
dtype
==
torch
::
k
Float
32
)
{
switch
(
tensor_
.
dtype
().
toScalarType
())
{
case
c10
::
ScalarType
::
Float
:
{
auto
ptr
=
tensor_
.
data_ptr
<
float_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_FLOAT_32BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
if
(
dtype
==
torch
::
kInt32
)
{
break
;
}
case
c10
::
ScalarType
::
Int
:
{
auto
ptr
=
tensor_
.
data_ptr
<
int32_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_SIGNED_32BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
if
(
dtype
==
torch
::
kInt16
)
{
break
;
}
case
c10
::
ScalarType
::
Short
:
{
auto
ptr
=
tensor_
.
data_ptr
<
int16_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_SIGNED_16BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
if
(
dtype
==
torch
::
kUInt8
)
{
break
;
}
case
c10
::
ScalarType
::
Byte
:
{
auto
ptr
=
tensor_
.
data_ptr
<
uint8_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_UNSIGNED_8BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
{
break
;
}
default:
throw
std
::
runtime_error
(
"Unexpected dtype."
);
}
priv
->
index
+=
*
osamp
;
...
...
torchaudio/csrc/sox/utils.cpp
View file @
e5c4de87
...
...
@@ -102,9 +102,13 @@ void validate_input_tensor(const torch::Tensor tensor) {
throw
std
::
runtime_error
(
"Input tensor has to be 2D."
);
}
const
auto
dtype
=
tensor
.
dtype
();
if
(
!
(
dtype
==
torch
::
kFloat32
||
dtype
==
torch
::
kInt32
||
dtype
==
torch
::
kInt16
||
dtype
==
torch
::
kUInt8
))
{
switch
(
tensor
.
dtype
().
toScalarType
())
{
case
c10
::
ScalarType
::
Byte
:
case
c10
::
ScalarType
::
Short
:
case
c10
::
ScalarType
::
Int
:
case
c10
::
ScalarType
::
Float
:
break
;
default:
throw
std
::
runtime_error
(
"Input tensor has to be one of float32, int32, int16 or uint8 type."
);
}
...
...
@@ -209,22 +213,25 @@ namespace {
std
::
tuple
<
sox_encoding_t
,
unsigned
>
get_save_encoding_for_wav
(
const
std
::
string
format
,
const
caffe2
::
TypeMeta
dtype
,
caffe2
::
TypeMeta
dtype
,
const
Encoding
&
encoding
,
const
BitDepth
&
bits_per_sample
)
{
switch
(
encoding
)
{
case
Encoding
::
NOT_PROVIDED
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
if
(
dtype
==
torch
::
kFloat32
)
switch
(
dtype
.
toScalarType
())
{
case
c10
::
ScalarType
::
Float
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_FLOAT
,
32
);
if
(
dtype
==
torch
::
k
Int
32
)
case
c10
::
ScalarType
::
Int
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
32
);
if
(
dtype
==
torch
::
kInt16
)
case
c10
::
ScalarType
::
Short
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
16
);
if
(
dtype
==
torch
::
kUInt8
)
case
c10
::
ScalarType
::
Byte
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_UNSIGNED
,
8
);
default:
throw
std
::
runtime_error
(
"Internal Error: Unexpected dtype."
);
}
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_UNSIGNED
,
8
);
default:
...
...
@@ -376,9 +383,7 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
}
}
unsigned
get_precision
(
const
std
::
string
filetype
,
const
caffe2
::
TypeMeta
dtype
)
{
unsigned
get_precision
(
const
std
::
string
filetype
,
caffe2
::
TypeMeta
dtype
)
{
if
(
filetype
==
"mp3"
)
return
SOX_UNSPEC
;
if
(
filetype
==
"flac"
)
...
...
@@ -386,16 +391,19 @@ unsigned get_precision(
if
(
filetype
==
"ogg"
||
filetype
==
"vorbis"
)
return
SOX_UNSPEC
;
if
(
filetype
==
"wav"
||
filetype
==
"amb"
)
{
if
(
dtype
==
torch
::
kUInt8
)
switch
(
dtype
.
toScalarType
())
{
case
c10
::
ScalarType
::
Byte
:
return
8
;
if
(
dtype
==
torch
::
kInt16
)
case
c10
::
ScalarType
::
Short
:
return
16
;
if
(
dtype
==
torch
::
k
Int
32
)
case
c10
::
ScalarType
::
Int
:
return
32
;
if
(
dtype
==
torch
::
k
Float
32
)
case
c10
::
ScalarType
::
Float
:
return
32
;
default:
throw
std
::
runtime_error
(
"Unsupported dtype."
);
}
}
if
(
filetype
==
"sph"
)
return
32
;
if
(
filetype
==
"amr-nb"
)
{
...
...
@@ -419,28 +427,34 @@ sox_signalinfo_t get_signalinfo(
/*length=*/
static_cast
<
uint64_t
>
(
waveform
->
numel
())};
}
sox_encodinginfo_t
get_tensor_encodinginfo
(
const
caffe2
::
TypeMeta
dtype
)
{
sox_encodinginfo_t
get_tensor_encodinginfo
(
caffe2
::
TypeMeta
dtype
)
{
sox_encoding_t
encoding
=
[
&
]()
{
if
(
dtype
==
torch
::
kUInt8
)
switch
(
dtype
.
toScalarType
())
{
case
c10
::
ScalarType
::
Byte
:
return
SOX_ENCODING_UNSIGNED
;
if
(
dtype
==
torch
::
kInt16
)
case
c10
::
ScalarType
::
Short
:
return
SOX_ENCODING_SIGN2
;
if
(
dtype
==
torch
::
k
Int
32
)
case
c10
::
ScalarType
::
Int
:
return
SOX_ENCODING_SIGN2
;
if
(
dtype
==
torch
::
k
Float
32
)
case
c10
::
ScalarType
::
Float
:
return
SOX_ENCODING_FLOAT
;
default:
throw
std
::
runtime_error
(
"Unsupported dtype."
);
}
}();
unsigned
bits_per_sample
=
[
&
]()
{
if
(
dtype
==
torch
::
kUInt8
)
switch
(
dtype
.
toScalarType
())
{
case
c10
::
ScalarType
::
Byte
:
return
8
;
if
(
dtype
==
torch
::
kInt16
)
case
c10
::
ScalarType
::
Short
:
return
16
;
if
(
dtype
==
torch
::
k
Int
32
)
case
c10
::
ScalarType
::
Int
:
return
32
;
if
(
dtype
==
torch
::
k
Float
32
)
case
c10
::
ScalarType
::
Float
:
return
32
;
default:
throw
std
::
runtime_error
(
"Unsupported dtype."
);
}
}();
return
sox_encodinginfo_t
{
/*encoding=*/
encoding
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment