Unverified Commit e5c4de87 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Replace dtype if-elseif-else with switch (#1270)

parent d58ac213
...@@ -80,29 +80,37 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -80,29 +80,37 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// Convert to sox_sample_t (int32_t) and write to buffer // Convert to sox_sample_t (int32_t) and write to buffer
SOX_SAMPLE_LOCALS; SOX_SAMPLE_LOCALS;
const auto dtype = tensor_.dtype(); switch (tensor_.dtype().toScalarType()) {
if (dtype == torch::kFloat32) { case c10::ScalarType::Float: {
auto ptr = tensor_.data_ptr<float_t>(); auto ptr = tensor_.data_ptr<float_t>();
for (size_t i = 0; i < *osamp; ++i) { for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips); obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips);
}
break;
} }
} else if (dtype == torch::kInt32) { case c10::ScalarType::Int: {
auto ptr = tensor_.data_ptr<int32_t>(); auto ptr = tensor_.data_ptr<int32_t>();
for (size_t i = 0; i < *osamp; ++i) { for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips); obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips);
}
break;
} }
} else if (dtype == torch::kInt16) { case c10::ScalarType::Short: {
auto ptr = tensor_.data_ptr<int16_t>(); auto ptr = tensor_.data_ptr<int16_t>();
for (size_t i = 0; i < *osamp; ++i) { for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips); obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips);
}
break;
} }
} else if (dtype == torch::kUInt8) { case c10::ScalarType::Byte: {
auto ptr = tensor_.data_ptr<uint8_t>(); auto ptr = tensor_.data_ptr<uint8_t>();
for (size_t i = 0; i < *osamp; ++i) { for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips); obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips);
}
break;
} }
} else { default:
throw std::runtime_error("Unexpected dtype."); throw std::runtime_error("Unexpected dtype.");
} }
priv->index += *osamp; priv->index += *osamp;
return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS; return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS;
......
...@@ -102,11 +102,15 @@ void validate_input_tensor(const torch::Tensor tensor) { ...@@ -102,11 +102,15 @@ void validate_input_tensor(const torch::Tensor tensor) {
throw std::runtime_error("Input tensor has to be 2D."); throw std::runtime_error("Input tensor has to be 2D.");
} }
const auto dtype = tensor.dtype(); switch (tensor.dtype().toScalarType()) {
if (!(dtype == torch::kFloat32 || dtype == torch::kInt32 || case c10::ScalarType::Byte:
dtype == torch::kInt16 || dtype == torch::kUInt8)) { case c10::ScalarType::Short:
throw std::runtime_error( case c10::ScalarType::Int:
"Input tensor has to be one of float32, int32, int16 or uint8 type."); case c10::ScalarType::Float:
break;
default:
throw std::runtime_error(
"Input tensor has to be one of float32, int32, int16 or uint8 type.");
} }
} }
...@@ -209,22 +213,25 @@ namespace { ...@@ -209,22 +213,25 @@ namespace {
std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav( std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
const std::string format, const std::string format,
const caffe2::TypeMeta dtype, caffe2::TypeMeta dtype,
const Encoding& encoding, const Encoding& encoding,
const BitDepth& bits_per_sample) { const BitDepth& bits_per_sample) {
switch (encoding) { switch (encoding) {
case Encoding::NOT_PROVIDED: case Encoding::NOT_PROVIDED:
switch (bits_per_sample) { switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED: case BitDepth::NOT_PROVIDED:
if (dtype == torch::kFloat32) switch (dtype.toScalarType()) {
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); case c10::ScalarType::Float:
if (dtype == torch::kInt32) return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); case c10::ScalarType::Int:
if (dtype == torch::kInt16) return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); case c10::ScalarType::Short:
if (dtype == torch::kUInt8) return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); case c10::ScalarType::Byte:
throw std::runtime_error("Internal Error: Unexpected dtype."); return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error("Internal Error: Unexpected dtype.");
}
case BitDepth::B8: case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default: default:
...@@ -376,9 +383,7 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding( ...@@ -376,9 +383,7 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
} }
} }
unsigned get_precision( unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
const std::string filetype,
const caffe2::TypeMeta dtype) {
if (filetype == "mp3") if (filetype == "mp3")
return SOX_UNSPEC; return SOX_UNSPEC;
if (filetype == "flac") if (filetype == "flac")
...@@ -386,15 +391,18 @@ unsigned get_precision( ...@@ -386,15 +391,18 @@ unsigned get_precision(
if (filetype == "ogg" || filetype == "vorbis") if (filetype == "ogg" || filetype == "vorbis")
return SOX_UNSPEC; return SOX_UNSPEC;
if (filetype == "wav" || filetype == "amb") { if (filetype == "wav" || filetype == "amb") {
if (dtype == torch::kUInt8) switch (dtype.toScalarType()) {
return 8; case c10::ScalarType::Byte:
if (dtype == torch::kInt16) return 8;
return 16; case c10::ScalarType::Short:
if (dtype == torch::kInt32) return 16;
return 32; case c10::ScalarType::Int:
if (dtype == torch::kFloat32) return 32;
return 32; case c10::ScalarType::Float:
throw std::runtime_error("Unsupported dtype."); return 32;
default:
throw std::runtime_error("Unsupported dtype.");
}
} }
if (filetype == "sph") if (filetype == "sph")
return 32; return 32;
...@@ -419,28 +427,34 @@ sox_signalinfo_t get_signalinfo( ...@@ -419,28 +427,34 @@ sox_signalinfo_t get_signalinfo(
/*length=*/static_cast<uint64_t>(waveform->numel())}; /*length=*/static_cast<uint64_t>(waveform->numel())};
} }
sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) { sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
sox_encoding_t encoding = [&]() { sox_encoding_t encoding = [&]() {
if (dtype == torch::kUInt8) switch (dtype.toScalarType()) {
return SOX_ENCODING_UNSIGNED; case c10::ScalarType::Byte:
if (dtype == torch::kInt16) return SOX_ENCODING_UNSIGNED;
return SOX_ENCODING_SIGN2; case c10::ScalarType::Short:
if (dtype == torch::kInt32) return SOX_ENCODING_SIGN2;
return SOX_ENCODING_SIGN2; case c10::ScalarType::Int:
if (dtype == torch::kFloat32) return SOX_ENCODING_SIGN2;
return SOX_ENCODING_FLOAT; case c10::ScalarType::Float:
throw std::runtime_error("Unsupported dtype."); return SOX_ENCODING_FLOAT;
default:
throw std::runtime_error("Unsupported dtype.");
}
}(); }();
unsigned bits_per_sample = [&]() { unsigned bits_per_sample = [&]() {
if (dtype == torch::kUInt8) switch (dtype.toScalarType()) {
return 8; case c10::ScalarType::Byte:
if (dtype == torch::kInt16) return 8;
return 16; case c10::ScalarType::Short:
if (dtype == torch::kInt32) return 16;
return 32; case c10::ScalarType::Int:
if (dtype == torch::kFloat32) return 32;
return 32; case c10::ScalarType::Float:
throw std::runtime_error("Unsupported dtype."); return 32;
default:
throw std::runtime_error("Unsupported dtype.");
}
}(); }();
return sox_encodinginfo_t{ return sox_encodinginfo_t{
/*encoding=*/encoding, /*encoding=*/encoding,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment