Unverified Commit e5c4de87 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Replace dtype if-elseif-else with switch (#1270)

parent d58ac213
...@@ -80,28 +80,36 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -80,28 +80,36 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// Convert to sox_sample_t (int32_t) and write to buffer // Convert to sox_sample_t (int32_t) and write to buffer
SOX_SAMPLE_LOCALS; SOX_SAMPLE_LOCALS;
const auto dtype = tensor_.dtype(); switch (tensor_.dtype().toScalarType()) {
if (dtype == torch::kFloat32) { case c10::ScalarType::Float: {
auto ptr = tensor_.data_ptr<float_t>(); auto ptr = tensor_.data_ptr<float_t>();
for (size_t i = 0; i < *osamp; ++i) { for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips); obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips);
} }
} else if (dtype == torch::kInt32) { break;
}
case c10::ScalarType::Int: {
auto ptr = tensor_.data_ptr<int32_t>(); auto ptr = tensor_.data_ptr<int32_t>();
for (size_t i = 0; i < *osamp; ++i) { for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips); obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips);
} }
} else if (dtype == torch::kInt16) { break;
}
case c10::ScalarType::Short: {
auto ptr = tensor_.data_ptr<int16_t>(); auto ptr = tensor_.data_ptr<int16_t>();
for (size_t i = 0; i < *osamp; ++i) { for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips); obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips);
} }
} else if (dtype == torch::kUInt8) { break;
}
case c10::ScalarType::Byte: {
auto ptr = tensor_.data_ptr<uint8_t>(); auto ptr = tensor_.data_ptr<uint8_t>();
for (size_t i = 0; i < *osamp; ++i) { for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips); obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips);
} }
} else { break;
}
default:
throw std::runtime_error("Unexpected dtype."); throw std::runtime_error("Unexpected dtype.");
} }
priv->index += *osamp; priv->index += *osamp;
......
...@@ -102,9 +102,13 @@ void validate_input_tensor(const torch::Tensor tensor) { ...@@ -102,9 +102,13 @@ void validate_input_tensor(const torch::Tensor tensor) {
throw std::runtime_error("Input tensor has to be 2D."); throw std::runtime_error("Input tensor has to be 2D.");
} }
const auto dtype = tensor.dtype(); switch (tensor.dtype().toScalarType()) {
if (!(dtype == torch::kFloat32 || dtype == torch::kInt32 || case c10::ScalarType::Byte:
dtype == torch::kInt16 || dtype == torch::kUInt8)) { case c10::ScalarType::Short:
case c10::ScalarType::Int:
case c10::ScalarType::Float:
break;
default:
throw std::runtime_error( throw std::runtime_error(
"Input tensor has to be one of float32, int32, int16 or uint8 type."); "Input tensor has to be one of float32, int32, int16 or uint8 type.");
} }
...@@ -209,22 +213,25 @@ namespace { ...@@ -209,22 +213,25 @@ namespace {
std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav( std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
const std::string format, const std::string format,
const caffe2::TypeMeta dtype, caffe2::TypeMeta dtype,
const Encoding& encoding, const Encoding& encoding,
const BitDepth& bits_per_sample) { const BitDepth& bits_per_sample) {
switch (encoding) { switch (encoding) {
case Encoding::NOT_PROVIDED: case Encoding::NOT_PROVIDED:
switch (bits_per_sample) { switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED: case BitDepth::NOT_PROVIDED:
if (dtype == torch::kFloat32) switch (dtype.toScalarType()) {
case c10::ScalarType::Float:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
if (dtype == torch::kInt32) case c10::ScalarType::Int:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
if (dtype == torch::kInt16) case c10::ScalarType::Short:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
if (dtype == torch::kUInt8) case c10::ScalarType::Byte:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error("Internal Error: Unexpected dtype."); throw std::runtime_error("Internal Error: Unexpected dtype.");
}
case BitDepth::B8: case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default: default:
...@@ -376,9 +383,7 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding( ...@@ -376,9 +383,7 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
} }
} }
unsigned get_precision( unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
const std::string filetype,
const caffe2::TypeMeta dtype) {
if (filetype == "mp3") if (filetype == "mp3")
return SOX_UNSPEC; return SOX_UNSPEC;
if (filetype == "flac") if (filetype == "flac")
...@@ -386,16 +391,19 @@ unsigned get_precision( ...@@ -386,16 +391,19 @@ unsigned get_precision(
if (filetype == "ogg" || filetype == "vorbis") if (filetype == "ogg" || filetype == "vorbis")
return SOX_UNSPEC; return SOX_UNSPEC;
if (filetype == "wav" || filetype == "amb") { if (filetype == "wav" || filetype == "amb") {
if (dtype == torch::kUInt8) switch (dtype.toScalarType()) {
case c10::ScalarType::Byte:
return 8; return 8;
if (dtype == torch::kInt16) case c10::ScalarType::Short:
return 16; return 16;
if (dtype == torch::kInt32) case c10::ScalarType::Int:
return 32; return 32;
if (dtype == torch::kFloat32) case c10::ScalarType::Float:
return 32; return 32;
default:
throw std::runtime_error("Unsupported dtype."); throw std::runtime_error("Unsupported dtype.");
} }
}
if (filetype == "sph") if (filetype == "sph")
return 32; return 32;
if (filetype == "amr-nb") { if (filetype == "amr-nb") {
...@@ -419,28 +427,34 @@ sox_signalinfo_t get_signalinfo( ...@@ -419,28 +427,34 @@ sox_signalinfo_t get_signalinfo(
/*length=*/static_cast<uint64_t>(waveform->numel())}; /*length=*/static_cast<uint64_t>(waveform->numel())};
} }
sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) { sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
sox_encoding_t encoding = [&]() { sox_encoding_t encoding = [&]() {
if (dtype == torch::kUInt8) switch (dtype.toScalarType()) {
case c10::ScalarType::Byte:
return SOX_ENCODING_UNSIGNED; return SOX_ENCODING_UNSIGNED;
if (dtype == torch::kInt16) case c10::ScalarType::Short:
return SOX_ENCODING_SIGN2; return SOX_ENCODING_SIGN2;
if (dtype == torch::kInt32) case c10::ScalarType::Int:
return SOX_ENCODING_SIGN2; return SOX_ENCODING_SIGN2;
if (dtype == torch::kFloat32) case c10::ScalarType::Float:
return SOX_ENCODING_FLOAT; return SOX_ENCODING_FLOAT;
default:
throw std::runtime_error("Unsupported dtype."); throw std::runtime_error("Unsupported dtype.");
}
}(); }();
unsigned bits_per_sample = [&]() { unsigned bits_per_sample = [&]() {
if (dtype == torch::kUInt8) switch (dtype.toScalarType()) {
case c10::ScalarType::Byte:
return 8; return 8;
if (dtype == torch::kInt16) case c10::ScalarType::Short:
return 16; return 16;
if (dtype == torch::kInt32) case c10::ScalarType::Int:
return 32; return 32;
if (dtype == torch::kFloat32) case c10::ScalarType::Float:
return 32; return 32;
default:
throw std::runtime_error("Unsupported dtype."); throw std::runtime_error("Unsupported dtype.");
}
}(); }();
return sox_encodinginfo_t{ return sox_encodinginfo_t{
/*encoding=*/encoding, /*encoding=*/encoding,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment