Unverified commit d947dee0, authored by moto-meta, committed by GitHub

Resolve lint issues

Differential Revision: D50205775

Pull Request resolved: https://github.com/pytorch/audio/pull/3651
parent 57f7f522
@@ -52,7 +52,6 @@ std::tuple<size_t, int> calculate_require_buff_and_init_internal_data(
     float threshold);
 int ctc_beam_search_decoder_batch_gpu(
     InternalData* inter_data,
-    float* pp,
     int blid,
     int spid,
     int* clist,
......
@@ -24,7 +24,7 @@ constexpr __device__ IntType log2(IntType num, IntType ret = IntType(0)) {
  */
 template <auto Value_>
 struct Pow2 {
-  typedef decltype(Value_) Type;
+  using Type = decltype(Value_);
   static constexpr Type Value = Value_;
   static constexpr Type Log2 = log2(Value);
   static constexpr Type Mask = Value - 1;
......
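The `typedef` → `using` change above is the standard modernize-use-using fix: the two spellings are equivalent for a plain member alias, but only alias declarations can be templated. A minimal host-only sketch of the pattern (the `log2` body below is a plausible reconstruction; the real header also marks it `__device__`, and `template <auto>` assumes C++17):

#include <cstddef>

// Reconstructed helper: count how many times num can be halved.
template <typename IntType>
constexpr IntType log2(IntType num, IntType ret = IntType(0)) {
  return num <= 1 ? ret : log2(num >> 1, ret + 1);
}

template <auto Value_>
struct Pow2 {
  using Type = decltype(Value_); // was: typedef decltype(Value_) Type;
  static constexpr Type Value = Value_;
  static constexpr Type Log2 = log2(Value);
  static constexpr Type Mask = Value - 1;
};

// Unlike typedef, a using alias could also be made a template, e.g.
// template <auto V> using Pow2Type = typename Pow2<V>::Type;
static_assert(Pow2<32>::Log2 == 5 && Pow2<32>::Mask == 31, "sanity check");

int main() {}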
@@ -72,8 +72,9 @@ std::tuple<size_t, int> calculate_require_buff_and_init_internal_data(
     const std::vector<int>& prob_strides,
     int blid,
     float threshold) {
-  if ((batch_size * beam * seq_len * vocab_size) <= 0)
+  if ((batch_size * beam * seq_len * vocab_size) <= 0) {
     return {0, 0};
+  }
   CHECK(prob_sizes.size() == 3, "only support 3D log_prob.");
   CHECK(prob_strides.size() == 3, "only support 3D log_prob. ");
@@ -123,8 +124,9 @@ std::tuple<size_t, int> calculate_require_buff_and_init_internal_data(
   require_size += select_seq_lens_align_size;
   require_size += ALIGN_BYTES;
-  if (require_size > buff_size)
+  if (require_size > buff_size) {
     return {require_size, 0};
+  }
   char* buff_align_ptr = reinterpret_cast<char*>(align_size(buff_ptr));
@@ -291,7 +293,6 @@ void prefixCTC_free(std::uintptr_t inter_data_ptr) {
 int ctc_beam_search_decoder_batch_gpu(
     InternalData* inter_data,
-    float* pp,
     int blid,
     int spid,
     int* clist,
......
@@ -35,9 +35,9 @@ constexpr int MAX_BLOCKS = 800;
 template <typename T>
 class DeviceDataWrap {
  public:
-  DeviceDataWrap() : data_{}, size_in_bytes_{} {};
+  DeviceDataWrap() : data_{}, size_in_bytes_{} {}
   DeviceDataWrap(T* data_ptr, size_t size_in_byte)
-      : data_{data_ptr}, size_in_bytes_{size_in_byte} {};
+      : data_{data_ptr}, size_in_bytes_{size_in_byte} {}
   void print(size_t offset, size_t size_in_element, int eles_per_row = 10)
       const {
     if ((offset + size_in_element) * sizeof(T) > size_in_bytes_) {
......
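The `{};` → `{}` cleanups above remove a stray semicolon after each constructor body; the extra `;` is an empty declaration that warnings such as `-Wextra-semi` flag. A reduced example:

struct Wrap {
  Wrap() {}  // OK: a function body needs no trailing semicolon
  // Wrap() {};  // compiles, but the ';' is an empty declaration
  //             // that -Wextra-semi (and similar lints) report
  int get() const { return v_; }
  int v_ = 0;
};

int main() { return Wrap{}.get(); }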
@@ -70,13 +70,7 @@ ctc_prefix_decoder_batch_wrapper(
   std::vector<int> len_data(batch_size * beam);
   std::vector<float> score(batch_size * beam);
   cu_ctc::ctc_beam_search_decoder_batch_gpu(
-      inter_data,
-      (float*)pp,
-      blid,
-      spid,
-      list_data.data(),
-      len_data.data(),
-      score.data());
+      inter_data, blid, spid, list_data.data(), len_data.data(), score.data());
   SCORE_TYPE score_hyps{};
   score_hyps.reserve(batch_size);
   for (int b = 0; b < batch_size; b++) {
......
@@ -116,11 +116,9 @@ void forced_align_impl(
   auto idx1 = (T - 1) % 2;
   auto ltrIdx = alphas_a[idx1][S - 1] > alphas_a[idx1][S - 2] ? S - 1 : S - 2;
   // path stores the token index for each time step after force alignment.
-  auto indexScores = 0;
   for (auto t = T - 1; t > -1; t--) {
     auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
     paths_a[batchIndex][t] = lbl_idx;
-    ++indexScores;
     ltrIdx -= backPtr_a[t][ltrIdx];
   }
 }
......
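`indexScores` was incremented on every iteration but never read, a textbook dead store of the kind clang's `deadcode.DeadStores` checker reports. A reduced sketch of the pattern, with hypothetical names:

// Minimal repro of the dead-store pattern removed above.
int sum_path(const int* back_ptr, int T) {
  int idx = 0; // hypothetical stand-in for ltrIdx
  // int indexScores = 0;       // dead: the value is never read
  for (int t = T - 1; t > -1; t--) {
    // ++indexScores;           // flagged as a dead store
    idx += back_ptr[t];
  }
  return idx;
}

int main() {
  int bp[3] = {1, 2, 3};
  return sum_path(bp, 3) == 6 ? 0 : 1;
}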
+#pragma once
 #include <torch/types.h>
 #define EPS ((scalar_t)(1e-5))
@@ -77,9 +78,9 @@ scalar_t cosine(const Wall<scalar_t>& wall, const torch::Tensor& dir) {
 /// 3D room
 template <typename T>
 const std::array<Wall<T>, 6> make_room(
-    const T w,
-    const T l,
-    const T h,
+    const T& w,
+    const T& l,
+    const T& h,
     const torch::Tensor& abs,
     const torch::Tensor& scat) {
   using namespace torch::indexing;
......
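Switching `const T` to `const T&` avoids a copy whenever `make_room` is instantiated with a non-trivially-copyable `T`; for plain floats the two are equivalent. A simplified sketch (`make_dims` is a hypothetical stand-in for the real signature):

#include <array>

// Read-only access to w/l/h, so a const reference suffices and
// never copies, whatever T turns out to be.
template <typename T>
std::array<T, 3> make_dims(const T& w, const T& l, const T& h) {
  return {w, l, h};
}

int main() {
  auto dims = make_dims(4.0f, 5.0f, 3.0f);
  return dims[0] == 4.0f ? 0 : 1;
}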
@@ -165,7 +165,6 @@ status_t ComputeLogProbs(
 template <typename DTYPE>
 DTYPE ComputeAlphaOneSequence(
-    const Options& options,
     TensorView<const LogProbs<DTYPE>>& logProbs,
     int srcLen,
     int tgtLen,
@@ -198,7 +197,6 @@ DTYPE ComputeAlphaOneSequence(
 template <typename DTYPE>
 DTYPE ComputeBetaOneSequence(
-    const Options& options,
     TensorView<const LogProbs<DTYPE>>& logProbs,
     int srcLen,
     int tgtLen,
@@ -240,14 +238,12 @@ DTYPE ComputeAlphaOrBetaOneSequence(
     TensorView<DTYPE>& beta) {
   if (thread & 1) {
     return ComputeAlphaOneSequence<DTYPE>(
-        /*options=*/options,
         /*logProbs=*/logProbs,
         /*srcLen=*/srcLen,
         /*tgtLen=*/tgtLen,
         /*alpha=*/alpha);
   } else {
     return ComputeBetaOneSequence<DTYPE>(
-        /*options=*/options,
         /*logProbs=*/logProbs,
         /*srcLen=*/srcLen,
         /*tgtLen=*/tgtLen,
@@ -488,7 +484,6 @@ void ComputeAlphas(
   //#pragma omp parallel for
   for (int i = 0; i < B; ++i) { // use max 2 * B threads.
     ComputeAlphaOneSequence<DTYPE>(
-        options,
         /*logProbs=*/seqlogProbs[i],
         /*srcLen=*/srcLengths[i],
         /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
@@ -524,7 +519,6 @@ void ComputeBetas(
   //#pragma omp parallel for
   for (int i = 0; i < B; ++i) {
     ComputeBetaOneSequence<DTYPE>(
-        options,
         /*logProbs=*/seqlogProbs[i],
         /*srcLen=*/srcLengths[i],
         /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
......
@@ -10,18 +10,20 @@ namespace math {
 template <typename DTYPE>
 FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
-  if (x > y)
+  if (x > y) {
     return x;
-  else
+  } else {
     return y;
+  }
 }
 template <typename DTYPE>
 FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
-  if (x > y)
+  if (x > y) {
     return y;
-  else
+  } else {
     return x;
+  }
 }
 // log_sum_exp
......
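The brace fixes in `max`/`min` are pure style here, but linters insist on them because an unbraced `if` guards only its first statement. A small illustration:

#include <cstdio>

int checked_min(int x, int y) {
  if (x > y) {
    return y;
  } else {
    return x;
  }
}

// Why linters insist: without braces, only the FIRST statement is guarded.
int risky(int err) {
  if (err)
    std::printf("error\n"); // guarded by the if
    // std::printf("cleanup\n"); // would run unconditionally despite the indent
  return err;
}

int main() { return checked_min(risky(0), 1); }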
-#include <libtorchaudio/rnnt/macros.h>
-const char* ToString(level_t level) {
-  switch (level) {
-    case INFO:
-      return "INFO";
-    case WARNING:
-      return "WARNING";
-    case ERROR:
-      return "ERROR";
-    case FATAL:
-      return "FATAL";
-    default:
-      return "UNKNOWN";
-  }
-}
@@ -12,10 +12,3 @@
 #define HOST_AND_DEVICE
 #define FORCE_INLINE inline
 #endif // USE_CUDA
-#include <cstring>
-#include <iostream>
-typedef enum { INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 } level_t;
-const char* ToString(level_t level);
 #pragma once
-//#include <iostream>
 #ifdef USE_CUDA
 #include <cuda_runtime.h>
 #endif // USE_CUDA
-#include <libtorchaudio/rnnt/macros.h>
 #include <libtorchaudio/rnnt/types.h>
+#include <ostream>
 namespace torchaudio {
 namespace rnnt {
-typedef struct Options {
+struct Options {
   // the device to compute transducer loss.
   device_t device_;
 #ifdef USE_CUDA
@@ -78,7 +76,7 @@ typedef struct Options {
     return os;
   }
-} Options;
+};
 } // namespace rnnt
 } // namespace torchaudio
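`typedef struct Options { ... } Options;` is a C idiom; in C++ the struct name (and likewise the enum names in the `types.h` hunks below) is already a type, so the typedef is redundant. For example:

typedef struct OptionsC_ { int device; } OptionsC; // C needs the typedef
struct Options { int device; };                    // C++: the name is a type
enum device_t { UNDEFINED = 0, CPU = 1, GPU = 2 }; // same for enums

int main() {
  OptionsC c{1};
  Options o{CPU}; // unscoped enum converts to int
  return (o.device == c.device) ? 0 : 1;
}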
-#include <libtorchaudio/rnnt/types.h>
-namespace torchaudio {
-namespace rnnt {
-const char* toString(status_t status) {
-  switch (status) {
-    case SUCCESS:
-      return "success";
-    case FAILURE:
-      return "failure";
-    case COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED:
-      return "compute_denominator_reduce_max_failed";
-    case COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED:
-      return "compute_denominator_reduce_sum_failed";
-    case COMPUTE_LOG_PROBS_FAILED:
-      return "compute_log_probs_failed";
-    case COMPUTE_ALPHAS_BETAS_COSTS_FAILED:
-      return "compute_alphas_betas_costs_failed";
-    case COMPUTE_GRADIENTS_FAILED:
-      return "compute_gradients_failed";
-    default:
-      return "unknown";
-  }
-}
-const char* toString(device_t device) {
-  switch (device) {
-    case UNDEFINED:
-      return "undefined";
-    case CPU:
-      return "cpu";
-    case GPU:
-      return "gpu";
-    default:
-      return "unknown";
-  }
-}
-} // namespace rnnt
-} // namespace torchaudio
@@ -3,7 +3,7 @@
 namespace torchaudio {
 namespace rnnt {
-typedef enum {
+enum status_t {
   SUCCESS = 0,
   FAILURE = 1,
   COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED = 2,
@@ -11,13 +11,9 @@ typedef enum {
   COMPUTE_LOG_PROBS_FAILED = 4,
   COMPUTE_ALPHAS_BETAS_COSTS_FAILED = 5,
   COMPUTE_GRADIENTS_FAILED = 6
-} status_t;
+};
-typedef enum { UNDEFINED = 0, CPU = 1, GPU = 2 } device_t;
+enum device_t { UNDEFINED = 0, CPU = 1, GPU = 2 };
-const char* toString(status_t status);
-const char* toString(device_t device);
 } // namespace rnnt
 } // namespace torchaudio
@@ -129,14 +129,14 @@ int file_output_flow(
 sox_effect_handler_t* get_tensor_input_handler() {
   static sox_effect_handler_t handler{
       /*name=*/"input_tensor",
-      /*usage=*/NULL,
+      /*usage=*/nullptr,
       /*flags=*/SOX_EFF_MCHAN,
-      /*getopts=*/NULL,
-      /*start=*/NULL,
-      /*flow=*/NULL,
+      /*getopts=*/nullptr,
+      /*start=*/nullptr,
+      /*flow=*/nullptr,
       /*drain=*/tensor_input_drain,
-      /*stop=*/NULL,
-      /*kill=*/NULL,
+      /*stop=*/nullptr,
+      /*kill=*/nullptr,
       /*priv_size=*/sizeof(TensorInputPriv)};
   return &handler;
 }
@@ -144,14 +144,14 @@ sox_effect_handler_t* get_tensor_input_handler() {
 sox_effect_handler_t* get_tensor_output_handler() {
   static sox_effect_handler_t handler{
       /*name=*/"output_tensor",
-      /*usage=*/NULL,
+      /*usage=*/nullptr,
       /*flags=*/SOX_EFF_MCHAN,
-      /*getopts=*/NULL,
-      /*start=*/NULL,
+      /*getopts=*/nullptr,
+      /*start=*/nullptr,
       /*flow=*/tensor_output_flow,
-      /*drain=*/NULL,
-      /*stop=*/NULL,
-      /*kill=*/NULL,
+      /*drain=*/nullptr,
+      /*stop=*/nullptr,
+      /*kill=*/nullptr,
       /*priv_size=*/sizeof(TensorOutputPriv)};
   return &handler;
 }
@@ -159,14 +159,14 @@ sox_effect_handler_t* get_tensor_output_handler() {
 sox_effect_handler_t* get_file_output_handler() {
   static sox_effect_handler_t handler{
       /*name=*/"output_file",
-      /*usage=*/NULL,
+      /*usage=*/nullptr,
       /*flags=*/SOX_EFF_MCHAN,
-      /*getopts=*/NULL,
-      /*start=*/NULL,
+      /*getopts=*/nullptr,
+      /*start=*/nullptr,
       /*flow=*/file_output_flow,
-      /*drain=*/NULL,
-      /*stop=*/NULL,
-      /*kill=*/NULL,
+      /*drain=*/nullptr,
+      /*stop=*/nullptr,
+      /*kill=*/nullptr,
       /*priv_size=*/sizeof(FileOutputPriv)};
   return &handler;
 }
@@ -208,7 +208,7 @@ SoxEffectsChain::~SoxEffectsChain() {
 }
 void SoxEffectsChain::run() {
-  sox_flow_effects(sec_, NULL, NULL);
+  sox_flow_effects(sec_, nullptr, nullptr);
 }
 void SoxEffectsChain::addInputTensor(
@@ -259,7 +259,7 @@ void SoxEffectsChain::addOutputFile(sox_format_t* sf) {
       sf->filename);
 }
-void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
+void SoxEffectsChain::addEffect(const std::vector<std::string>& effect) {
   const auto num_args = effect.size();
   TORCH_CHECK(num_args != 0, "Invalid argument: empty effect.");
   const auto name = effect[0];
......
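The `NULL` → `nullptr` sweep is more than style: `NULL` is an integer constant (`0` or `0L`), so it can land in integer overloads, while `nullptr` converts only to pointer types. For instance:

#include <cstddef>

int pick(int) { return 1; }         // integer overload
int pick(const char*) { return 2; } // pointer overload

int main() {
  // pick(NULL) is ambiguous or selects pick(int), depending on how the
  // platform defines NULL; nullptr always selects the pointer overload.
  return pick(nullptr) == 2 ? 0 : 1;
}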
@@ -12,7 +12,7 @@ namespace torchaudio::sox {
 struct SoxEffect {
   explicit SoxEffect(sox_effect_t* se) noexcept;
   SoxEffect(const SoxEffect& other) = delete;
-  SoxEffect(const SoxEffect&& other) = delete;
+  SoxEffect(SoxEffect&& other) = delete;
   auto operator=(const SoxEffect& other) -> SoxEffect& = delete;
   auto operator=(SoxEffect&& other) -> SoxEffect& = delete;
   ~SoxEffect();
@@ -39,7 +39,7 @@ class SoxEffectsChain {
       sox_encodinginfo_t input_encoding,
       sox_encodinginfo_t output_encoding);
   SoxEffectsChain(const SoxEffectsChain& other) = delete;
-  SoxEffectsChain(const SoxEffectsChain&& other) = delete;
+  SoxEffectsChain(SoxEffectsChain&& other) = delete;
   SoxEffectsChain& operator=(const SoxEffectsChain& other) = delete;
   SoxEffectsChain& operator=(SoxEffectsChain&& other) = delete;
   ~SoxEffectsChain();
@@ -51,7 +51,7 @@ class SoxEffectsChain {
   void addInputFile(sox_format_t* sf);
   void addOutputBuffer(std::vector<sox_sample_t>* output_buffer);
   void addOutputFile(sox_format_t* sf);
-  void addEffect(const std::vector<std::string> effect);
+  void addEffect(const std::vector<std::string>& effect);
   int64_t getOutputNumChannels();
   int64_t getOutputSampleRate();
 };
......
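The fix from `const SoxEffect&&` to `SoxEffect&&` is the substantive change in this hunk: a `const T&&` parameter does not declare the move constructor, so deleting it never suppressed moves as intended. A minimal demonstration:

#include <type_traits>

struct NoMove {
  NoMove() = default;
  NoMove(const NoMove&) = delete;
  NoMove(NoMove&&) = delete; // the actual move-constructor signature
  // NoMove(const NoMove&&) = delete; // a distinct, rarely useful overload
  NoMove& operator=(const NoMove&) = delete;
  NoMove& operator=(NoMove&&) = delete;
};

static_assert(
    !std::is_move_constructible<NoMove>::value,
    "deleting NoMove(NoMove&&) is what suppresses moves");

int main() {}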
@@ -78,8 +78,9 @@ void save_audio_file(
   validate_input_tensor(tensor);
   const auto filetype = [&]() {
-    if (format.has_value())
+    if (format.has_value()) {
       return format.value();
+    }
     return get_filetype(path);
   }();
......
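`filetype` is initialized by an immediately-invoked lambda, a common idiom for computing a `const` local with early returns; the hunk only braces the `if` inside it. A self-contained sketch of the idiom (`resolve_filetype` and `get_filetype_from_path` are hypothetical stand-ins):

#include <optional>
#include <string>

// Hypothetical stand-in for get_filetype(path): take the extension.
std::string get_filetype_from_path(const std::string& path) {
  auto pos = path.find_last_of('.');
  return pos == std::string::npos ? "" : path.substr(pos + 1);
}

std::string resolve_filetype(
    const std::optional<std::string>& format,
    const std::string& path) {
  // The lambda runs once, immediately, so filetype can stay const.
  const auto filetype = [&]() {
    if (format.has_value()) {
      return format.value();
    }
    return get_filetype_from_path(path);
  }();
  return filetype;
}

int main() {
  return resolve_filetype(std::nullopt, "a.wav") == "wav" ? 0 : 1;
}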
@@ -3,26 +3,36 @@
 namespace torchaudio::sox {
 Format get_format_from_string(const std::string& format) {
-  if (format == "wav")
+  if (format == "wav") {
     return Format::WAV;
+  }
-  if (format == "mp3")
+  if (format == "mp3") {
     return Format::MP3;
+  }
-  if (format == "flac")
+  if (format == "flac") {
     return Format::FLAC;
+  }
-  if (format == "ogg" || format == "vorbis")
+  if (format == "ogg" || format == "vorbis") {
     return Format::VORBIS;
+  }
-  if (format == "amr-nb")
+  if (format == "amr-nb") {
     return Format::AMR_NB;
+  }
-  if (format == "amr-wb")
+  if (format == "amr-wb") {
     return Format::AMR_WB;
+  }
-  if (format == "amb")
+  if (format == "amb") {
     return Format::AMB;
+  }
-  if (format == "sph")
+  if (format == "sph") {
     return Format::SPHERE;
+  }
-  if (format == "htk")
+  if (format == "htk") {
     return Format::HTK;
+  }
-  if (format == "gsm")
+  if (format == "gsm") {
     return Format::GSM;
+  }
   TORCH_CHECK(false, "Internal Error: unexpected format value: ", format);
 }
......
@@ -5,6 +5,14 @@
 namespace torchaudio::sox {
+const std::unordered_set<std::string> UNSUPPORTED_EFFECTS{
+    "input",
+    "output",
+    "spectrogram",
+    "noiseprof",
+    "noisered",
+    "splice"};
 void set_seed(const int64_t seed) {
   sox_get_globals()->ranqd1 = static_cast<sox_int32_t>(seed);
 }
@@ -46,8 +54,9 @@ std::vector<std::string> list_write_formats() {
   for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
     const sox_format_handler_t* handler = fns->fn();
     for (const char* const* names = handler->names; *names; ++names) {
-      if (!strchr(*names, '/') && handler->write)
+      if (!strchr(*names, '/') && handler->write) {
         formats.emplace_back(*names);
+      }
     }
   }
   return formats;
@@ -58,8 +67,9 @@ std::vector<std::string> list_read_formats() {
   for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
     const sox_format_handler_t* handler = fns->fn();
     for (const char* const* names = handler->names; *names; ++names) {
-      if (!strchr(*names, '/') && handler->read)
+      if (!strchr(*names, '/') && handler->read) {
         formats.emplace_back(*names);
+      }
     }
   }
   return formats;
@@ -193,7 +203,7 @@ const std::string get_filetype(const std::string& path) {
 namespace {
 std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
-    const std::string format,
+    const std::string& format,
     caffe2::TypeMeta dtype,
     const Encoding& encoding,
     const BitDepth& bits_per_sample) {
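Taking `const std::string format` by value copies the argument at every call; a `const std::string&` reads it in place, which is what clang-tidy's performance-unnecessary-value-param suggests for read-only parameters:

#include <string>

// Before: int encode(const std::string format);  // copies per call
// After: a const reference avoids the copy for read-only use.
int encode(const std::string& format) {
  return format == "wav" ? 16 : 0;
}

int main() {
  std::string fmt = "wav";
  return encode(fmt) == 16 ? 0 : 1;
}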
@@ -386,12 +396,15 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
 }
 unsigned get_precision(const std::string& filetype, caffe2::TypeMeta dtype) {
-  if (filetype == "mp3")
+  if (filetype == "mp3") {
     return SOX_UNSPEC;
+  }
-  if (filetype == "flac")
+  if (filetype == "flac") {
     return 24;
+  }
-  if (filetype == "ogg" || filetype == "vorbis")
+  if (filetype == "ogg" || filetype == "vorbis") {
     return SOX_UNSPEC;
+  }
   if (filetype == "wav" || filetype == "amb") {
     switch (dtype.toScalarType()) {
       case c10::ScalarType::Byte:
@@ -406,8 +419,9 @@ unsigned get_precision(const std::string& filetype, caffe2::TypeMeta dtype) {
       TORCH_CHECK(false, "Unsupported dtype: ", dtype);
     }
   }
-  if (filetype == "sph")
+  if (filetype == "sph") {
     return 32;
+  }
   if (filetype == "amr-nb") {
     return 16;
   }
@@ -432,7 +446,8 @@ sox_signalinfo_t get_signalinfo(
       /*channels=*/
       static_cast<unsigned>(waveform->size(channels_first ? 0 : 1)),
       /*precision=*/get_precision(filetype, waveform->dtype()),
-      /*length=*/static_cast<uint64_t>(waveform->numel())};
+      /*length=*/static_cast<uint64_t>(waveform->numel()),
+      nullptr};
 }
 sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
......
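`sox_signalinfo_t` ends with a `mult` pointer field that the aggregate initializer above left implicit, which `-Wmissing-field-initializers` flags; the hunk spells out the trailing `nullptr`. A reduced sketch (`SignalInfo` is a hypothetical stand-in for the sox struct):

#include <cstdint>

struct SignalInfo { // hypothetical stand-in for sox_signalinfo_t
  double rate;
  unsigned channels;
  unsigned precision;
  uint64_t length;
  double* mult; // the field the original initializer left implicit
};

int main() {
  // Omitting mult would still zero-initialize it, but the warning fires;
  // writing the nullptr makes the intent explicit.
  SignalInfo si{8000.0, 1u, 16u, 8000u, nullptr};
  return si.mult == nullptr ? 0 : 1;
}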
@@ -31,8 +31,7 @@ std::vector<std::string> list_write_formats();
 // Utilities for sox_io / sox_effects implementations
 ////////////////////////////////////////////////////////////////////////////////
-const std::unordered_set<std::string> UNSUPPORTED_EFFECTS =
-    {"input", "output", "spectrogram", "noiseprof", "noisered", "splice"};
+extern const std::unordered_set<std::string> UNSUPPORTED_EFFECTS;
 /// helper class to automatically close sox_format_t*
 struct SoxFormat {
......
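A `const` variable at namespace scope has internal linkage, so defining `UNSUPPORTED_EFFECTS` in the header gave every translation unit that includes it its own copy of the set. The fix pairs an `extern` declaration in the header with the single definition added to the `.cpp` above. Collapsed into one file, the pattern looks like:

#include <string>
#include <unordered_set>

// Header side (sketch): extern gives the name external linkage and
// promises a definition elsewhere.
extern const std::unordered_set<std::string> UNSUPPORTED_EFFECTS;

// Source side (sketch): the one definition every TU links against.
const std::unordered_set<std::string> UNSUPPORTED_EFFECTS{
    "input", "output", "spectrogram", "noiseprof", "noisered", "splice"};

int main() {
  return UNSUPPORTED_EFFECTS.count("splice") == 1 ? 0 : 1;
}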