Unverified Commit e5c4de87 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Replace dtype if-elseif-else with switch (#1270)

parent d58ac213
......@@ -80,29 +80,37 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// Convert to sox_sample_t (int32_t) and write to buffer
SOX_SAMPLE_LOCALS;
const auto dtype = tensor_.dtype();
if (dtype == torch::kFloat32) {
auto ptr = tensor_.data_ptr<float_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips);
switch (tensor_.dtype().toScalarType()) {
case c10::ScalarType::Float: {
auto ptr = tensor_.data_ptr<float_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips);
}
break;
}
} else if (dtype == torch::kInt32) {
auto ptr = tensor_.data_ptr<int32_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips);
case c10::ScalarType::Int: {
auto ptr = tensor_.data_ptr<int32_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips);
}
break;
}
} else if (dtype == torch::kInt16) {
auto ptr = tensor_.data_ptr<int16_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips);
case c10::ScalarType::Short: {
auto ptr = tensor_.data_ptr<int16_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips);
}
break;
}
} else if (dtype == torch::kUInt8) {
auto ptr = tensor_.data_ptr<uint8_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips);
case c10::ScalarType::Byte: {
auto ptr = tensor_.data_ptr<uint8_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips);
}
break;
}
} else {
throw std::runtime_error("Unexpected dtype.");
default:
throw std::runtime_error("Unexpected dtype.");
}
priv->index += *osamp;
return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS;
......
......@@ -102,11 +102,15 @@ void validate_input_tensor(const torch::Tensor tensor) {
throw std::runtime_error("Input tensor has to be 2D.");
}
const auto dtype = tensor.dtype();
if (!(dtype == torch::kFloat32 || dtype == torch::kInt32 ||
dtype == torch::kInt16 || dtype == torch::kUInt8)) {
throw std::runtime_error(
"Input tensor has to be one of float32, int32, int16 or uint8 type.");
switch (tensor.dtype().toScalarType()) {
case c10::ScalarType::Byte:
case c10::ScalarType::Short:
case c10::ScalarType::Int:
case c10::ScalarType::Float:
break;
default:
throw std::runtime_error(
"Input tensor has to be one of float32, int32, int16 or uint8 type.");
}
}
......@@ -209,22 +213,25 @@ namespace {
std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
const std::string format,
const caffe2::TypeMeta dtype,
caffe2::TypeMeta dtype,
const Encoding& encoding,
const BitDepth& bits_per_sample) {
switch (encoding) {
case Encoding::NOT_PROVIDED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
if (dtype == torch::kFloat32)
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
if (dtype == torch::kInt32)
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
if (dtype == torch::kInt16)
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
if (dtype == torch::kUInt8)
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
throw std::runtime_error("Internal Error: Unexpected dtype.");
switch (dtype.toScalarType()) {
case c10::ScalarType::Float:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
case c10::ScalarType::Int:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
case c10::ScalarType::Short:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
case c10::ScalarType::Byte:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error("Internal Error: Unexpected dtype.");
}
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
......@@ -376,9 +383,7 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
}
}
unsigned get_precision(
const std::string filetype,
const caffe2::TypeMeta dtype) {
unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
if (filetype == "mp3")
return SOX_UNSPEC;
if (filetype == "flac")
......@@ -386,15 +391,18 @@ unsigned get_precision(
if (filetype == "ogg" || filetype == "vorbis")
return SOX_UNSPEC;
if (filetype == "wav" || filetype == "amb") {
if (dtype == torch::kUInt8)
return 8;
if (dtype == torch::kInt16)
return 16;
if (dtype == torch::kInt32)
return 32;
if (dtype == torch::kFloat32)
return 32;
throw std::runtime_error("Unsupported dtype.");
switch (dtype.toScalarType()) {
case c10::ScalarType::Byte:
return 8;
case c10::ScalarType::Short:
return 16;
case c10::ScalarType::Int:
return 32;
case c10::ScalarType::Float:
return 32;
default:
throw std::runtime_error("Unsupported dtype.");
}
}
if (filetype == "sph")
return 32;
......@@ -419,28 +427,34 @@ sox_signalinfo_t get_signalinfo(
/*length=*/static_cast<uint64_t>(waveform->numel())};
}
sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) {
sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
sox_encoding_t encoding = [&]() {
if (dtype == torch::kUInt8)
return SOX_ENCODING_UNSIGNED;
if (dtype == torch::kInt16)
return SOX_ENCODING_SIGN2;
if (dtype == torch::kInt32)
return SOX_ENCODING_SIGN2;
if (dtype == torch::kFloat32)
return SOX_ENCODING_FLOAT;
throw std::runtime_error("Unsupported dtype.");
switch (dtype.toScalarType()) {
case c10::ScalarType::Byte:
return SOX_ENCODING_UNSIGNED;
case c10::ScalarType::Short:
return SOX_ENCODING_SIGN2;
case c10::ScalarType::Int:
return SOX_ENCODING_SIGN2;
case c10::ScalarType::Float:
return SOX_ENCODING_FLOAT;
default:
throw std::runtime_error("Unsupported dtype.");
}
}();
unsigned bits_per_sample = [&]() {
if (dtype == torch::kUInt8)
return 8;
if (dtype == torch::kInt16)
return 16;
if (dtype == torch::kInt32)
return 32;
if (dtype == torch::kFloat32)
return 32;
throw std::runtime_error("Unsupported dtype.");
switch (dtype.toScalarType()) {
case c10::ScalarType::Byte:
return 8;
case c10::ScalarType::Short:
return 16;
case c10::ScalarType::Int:
return 32;
case c10::ScalarType::Float:
return 32;
default:
throw std::runtime_error("Unsupported dtype.");
}
}();
return sox_encodinginfo_t{
/*encoding=*/encoding,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment