"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "3645e493e374f85e90430c7ae13c0575a1f7c7fb"
Unverified commit 135e966d, authored by moto and committed by GitHub
Browse files

Refactor `get_encodinginfo` logic (#1233)

* Distinguish the `get_encodinginfo` used for Tensor I/O from the one used for saving output

* Isolate `get_tensor_encodinginfo` so that it no longer relies on the same helper functions as the save path
parent 8b93bd68
...@@ -60,8 +60,8 @@ std::tuple<torch::Tensor, int64_t> apply_effects_tensor( ...@@ -60,8 +60,8 @@ std::tuple<torch::Tensor, int64_t> apply_effects_tensor(
// Create SoxEffectsChain // Create SoxEffectsChain
const auto dtype = waveform.dtype(); const auto dtype = waveform.dtype();
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", dtype), /*input_encoding=*/get_tensor_encodinginfo(dtype),
/*output_encoding=*/get_encodinginfo("wav", dtype)); /*output_encoding=*/get_tensor_encodinginfo(dtype));
// Prepare output buffer // Prepare output buffer
std::vector<sox_sample_t> out_buffer; std::vector<sox_sample_t> out_buffer;
...@@ -112,7 +112,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_file( ...@@ -112,7 +112,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_file(
// Create and run SoxEffectsChain // Create and run SoxEffectsChain
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/sf->encoding, /*input_encoding=*/sf->encoding,
/*output_encoding=*/get_encodinginfo("wav", dtype)); /*output_encoding=*/get_tensor_encodinginfo(dtype));
chain.addInputFile(sf); chain.addInputFile(sf);
for (const auto& effect : effects) { for (const auto& effect : effects) {
...@@ -214,7 +214,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj( ...@@ -214,7 +214,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/sf->encoding, /*input_encoding=*/sf->encoding,
/*output_encoding=*/get_encodinginfo("wav", dtype)); /*output_encoding=*/get_tensor_encodinginfo(dtype));
chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj); chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj);
for (const auto& effect : effects) { for (const auto& effect : effects) {
chain.addEffect(effect); chain.addEffect(effect);
......
...@@ -143,7 +143,8 @@ void save_audio_file( ...@@ -143,7 +143,8 @@ void save_audio_file(
} }
const auto signal_info = const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first); get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression); const auto encoding_info =
get_encodinginfo_for_save(filetype, tgt_dtype, compression);
SoxFormat sf(sox_open_write( SoxFormat sf(sox_open_write(
path.c_str(), path.c_str(),
...@@ -158,7 +159,7 @@ void save_audio_file( ...@@ -158,7 +159,7 @@ void save_audio_file(
} }
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", tensor.dtype()), /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
/*output_encoding=*/sf->encoding); /*output_encoding=*/sf->encoding);
chain.addInputTensor(&tensor, sample_rate, channels_first); chain.addInputTensor(&tensor, sample_rate, channels_first);
chain.addOutputFile(sf); chain.addOutputFile(sf);
...@@ -281,7 +282,8 @@ void save_audio_fileobj( ...@@ -281,7 +282,8 @@ void save_audio_fileobj(
} }
const auto signal_info = const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first); get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression); const auto encoding_info =
get_encodinginfo_for_save(filetype, tgt_dtype, compression);
AutoReleaseBuffer buffer; AutoReleaseBuffer buffer;
...@@ -299,7 +301,7 @@ void save_audio_fileobj( ...@@ -299,7 +301,7 @@ void save_audio_fileobj(
} }
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", tensor.dtype()), /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
/*output_encoding=*/sf->encoding); /*output_encoding=*/sf->encoding);
chain.addInputTensor(&tensor, sample_rate, channels_first); chain.addInputTensor(&tensor, sample_rate, channels_first);
chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj); chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj);
......
...@@ -291,12 +291,32 @@ sox_signalinfo_t get_signalinfo( ...@@ -291,12 +291,32 @@ sox_signalinfo_t get_signalinfo(
/*length=*/static_cast<uint64_t>(waveform->numel())}; /*length=*/static_cast<uint64_t>(waveform->numel())};
} }
sox_encodinginfo_t get_encodinginfo( sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) {
const std::string filetype, sox_encoding_t encoding = [&]() {
const caffe2::TypeMeta dtype) { if (dtype == torch::kUInt8)
return SOX_ENCODING_UNSIGNED;
if (dtype == torch::kInt16)
return SOX_ENCODING_SIGN2;
if (dtype == torch::kInt32)
return SOX_ENCODING_SIGN2;
if (dtype == torch::kFloat32)
return SOX_ENCODING_FLOAT;
throw std::runtime_error("Unsupported dtype.");
}();
unsigned bits_per_sample = [&]() {
if (dtype == torch::kUInt8)
return 8;
if (dtype == torch::kInt16)
return 16;
if (dtype == torch::kInt32)
return 32;
if (dtype == torch::kFloat32)
return 32;
throw std::runtime_error("Unsupported dtype.");
}();
return sox_encodinginfo_t{ return sox_encodinginfo_t{
/*encoding=*/get_encoding(filetype, dtype), /*encoding=*/encoding,
/*bits_per_sample=*/get_precision(filetype, dtype), /*bits_per_sample=*/bits_per_sample,
/*compression=*/HUGE_VAL, /*compression=*/HUGE_VAL,
/*reverse_bytes=*/sox_option_default, /*reverse_bytes=*/sox_option_default,
/*reverse_nibbles=*/sox_option_default, /*reverse_nibbles=*/sox_option_default,
...@@ -304,7 +324,7 @@ sox_encodinginfo_t get_encodinginfo( ...@@ -304,7 +324,7 @@ sox_encodinginfo_t get_encodinginfo(
/*opposite_endian=*/sox_false}; /*opposite_endian=*/sox_false};
} }
sox_encodinginfo_t get_encodinginfo( sox_encodinginfo_t get_encodinginfo_for_save(
const std::string filetype, const std::string filetype,
const caffe2::TypeMeta dtype, const caffe2::TypeMeta dtype,
c10::optional<double>& compression) { c10::optional<double>& compression) {
......
...@@ -108,12 +108,11 @@ sox_signalinfo_t get_signalinfo( ...@@ -108,12 +108,11 @@ sox_signalinfo_t get_signalinfo(
const std::string filetype, const std::string filetype,
const bool channels_first); const bool channels_first);
/// Get sox_encofinginfo_t for saving audoi file /// Get sox_encodinginfo_t for Tensor I/O
sox_encodinginfo_t get_encodinginfo( sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype);
const std::string filetype,
const caffe2::TypeMeta dtype);
sox_encodinginfo_t get_encodinginfo( /// Get sox_encodinginfo_t for saving to file/file object
sox_encodinginfo_t get_encodinginfo_for_save(
const std::string filetype, const std::string filetype,
const caffe2::TypeMeta dtype, const caffe2::TypeMeta dtype,
c10::optional<double>& compression); c10::optional<double>& compression);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment