Commit 9dcc7a15 authored by flyingdown's avatar flyingdown
Browse files

init v0.10.0

parent db2b0b79
Pipeline #254 failed with stages
in 0 seconds
#pragma once
#ifdef USE_C10_HALF
#include "c10/util/Half.h"
#endif // USE_C10_HALF
#include <torchaudio/csrc/rnnt/macros.h>
namespace torchaudio {
namespace rnnt {
// Thin wrapper around the native __half type that converts to and from float.
struct alignas(sizeof(__half)) Half {
  __half x;

  HOST_AND_DEVICE Half() = default;

  // Converts from float using round-to-nearest; when that result reads back
  // as infinite, the conversion is redone rounding toward zero instead.
  FORCE_INLINE HOST_AND_DEVICE Half(float f) {
    const __half nearest = __float2half_rn(f);
    if (isinf(__half2float(nearest))) {
      x = __float2half_rz(f);
    } else {
      x = nearest;
    }
  }

  // Wraps an existing __half value unchanged.
  FORCE_INLINE HOST_AND_DEVICE Half(__half f) : x(f) {}

  // Widens the stored value to float.
  FORCE_INLINE HOST_AND_DEVICE operator float() const {
    const float widened = __half2float(x);
    return widened;
  }

  // Exposes the stored native value.
  FORCE_INLINE HOST_AND_DEVICE operator __half() const {
    return x;
  }
};
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <cassert>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
namespace torchaudio {
namespace rnnt {
// Returns true when `val` lies in the closed interval [start, end].
inline HOST_AND_DEVICE bool in_range(
    int start,
    int end, // inclusive
    int val) {
  if (val < start) {
    return false;
  }
  return !(val > end);
}
#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1
// Flattens a (index1, index2) pair into a row-major linear offset, where
// `size2` is the extent of the second dimension.
struct Indexer2D {
  // Stored by value: a reference member would dangle as soon as the indexer
  // is constructed from a temporary expression (e.g. `Indexer2D(n - 1)`).
  const int size2_;

  FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {}

  // Returns index1 * size2 + index2.
  FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) const {
    return index1 * size2_ + index2;
  }
};
// Flattens a (index1, index2, index3) triple into a row-major linear offset,
// where `size2` and `size3` are the extents of the second and third
// dimensions.
struct Indexer3D {
  // Stored by value so the indexer never holds a dangling reference when it
  // is constructed from temporaries.
  const int size2_;
  const int size3_;

  FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3)
      : size2_(size2), size3_(size3) {}

  // Returns (index1 * size2 + index2) * size3 + index3.
  FORCE_INLINE HOST_AND_DEVICE int operator()(
      int index1,
      int index2,
      int index3) const {
    return (index1 * size2_ + index2) * size3_ + index3;
  }
};
// Flattens a (index1, index2, index3, index4) tuple into a row-major linear
// offset, where `size2`..`size4` are the extents of the trailing dimensions.
struct Indexer4D {
  // Stored by value so the indexer never holds a dangling reference when it
  // is constructed from temporaries.
  const int size2_;
  const int size3_;
  const int size4_;

  HOST_AND_DEVICE Indexer4D(
      const int& size2,
      const int& size3,
      const int& size4)
      : size2_(size2), size3_(size3), size4_(size4) {}

  // Returns ((index1 * size2 + index2) * size3 + index3) * size4 + index4.
  HOST_AND_DEVICE int operator()(
      int index1,
      int index2,
      int index3,
      int index4) const {
    return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4;
  }
};
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <cassert>
#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
namespace torchaudio {
namespace rnnt {
// Writes the gradient entries for one (bTgt, t, u) lattice cell, i.e. the
// `numTargets` consecutive values gradients[idx(bTgt, t, u) * D + d].
//
// Parameters:
//   bTgt, t, u     - cell coordinates; the source-batch row is bTgt / H.
//   maxSrcLen      - extent of the t dimension in the padded arrays (maxT).
//   maxTgtLen      - extent of the u dimension in the padded arrays (maxU).
//   numTargets     - extent of the innermost dimension of logits/gradients.
//   blank          - index d treated as the blank transition.
//   clamp          - when > 0, each gradient is limited to [-clamp, clamp].
//   logits         - read at idx(bTgt, t, u) * D + d.
//   targets        - read at bTgt * (maxU - 1) + u.
//   srcLengths     - valid length T for row bTgt / H.
//   tgtLengths     - valid target length for row bTgt (U = length + 1).
//   denominators, alphas, betas - read at idx(bTgt, t, u) (and neighbours for
//                    betas); betas[bTgt * maxT * maxU] supplies the cost.
//   gradients      - output; may alias `logits`.
//   H              - number of target rows per source row.
//
// Cells outside the valid (T, U) region are left untouched unless `gradients`
// aliases `logits`, in which case they are zeroed. When the cost is inf/NaN
// all D entries of the cell are zeroed.
template <typename DTYPE, typename CAST_DTYPE>
HOST_AND_DEVICE void ComputeGradientsElement(
    int bTgt,
    int t,
    int u,
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    CAST_DTYPE clamp,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int& D = numTargets;

  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  if (t >= T || u >= U) { // out of boundary.
    if (gradients == logits && t < maxT && u < maxU) {
      // gradients and logits are pointing to the same memory location, so the
      // padded cell still holds logits and has to be cleared explicitly.
      Indexer3D idxr3(maxT, maxU);
      int idx_b_t_u_zero = idxr3(bTgt, t, u);
      if (idx_b_t_u_zero != -1) {
        int start = idx_b_t_u_zero * D;
        for (int b_t_u_d = start; b_t_u_d < start + D; ++b_t_u_d) {
          gradients[b_t_u_d] = 0;
        }
      }
    }
    return;
  }

  // The cost is the negated beta value at the first cell of this row.
  int costIdx = bTgt * maxT * maxU;
  CAST_DTYPE cost = -(betas[costIdx]);

  // Keep the row width in a named local: Indexer2D takes its size by
  // reference, so passing the temporary `maxU - 1` directly could leave the
  // indexer referring to a dead object.
  const int targetsRowLen = maxU - 1;
  Indexer2D idxr2(targetsRowLen);

  int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u;
  Indexer3D idxr3(maxT, maxU);
  idx_b_t_u = idxr3(bTgt, t, u);
  idx_b_t_up1 = idxr3(bTgt, t, u + 1);
  idx_b_tp1_u = idxr3(bTgt, t + 1, u);

  if (idx_b_t_u == -1) {
    return;
  }

  if (isinf(cost) || isnan(cost)) {
    for (int d = 0; d < D; ++d) {
      int b_t_u_d = idx_b_t_u * D + d;
      gradients[b_t_u_d] = 0;
    }
    return;
  }

  CAST_DTYPE c = alphas[idx_b_t_u] + cost - denominators[idx_b_t_u];
  for (int d = 0; d < D; ++d) {
    int b_t_u_d = idx_b_t_u * D + d;
    CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;

    if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
    } else if (t < T - 1 && d == blank) {
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
      if (idx_b_tp1_u != -1) {
        gradients[b_t_u_d] =
            gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
      }
    } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
      if (idx_b_t_up1 != -1) {
        gradients[b_t_u_d] =
            gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
      }
    } else {
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
    }

    if (clamp > 0) {
      // Apply both bounds to the same value; clamping in two independent
      // assignments from the unclamped value would drop the upper bound.
      const CAST_DTYPE unclamped = CAST_DTYPE(gradients[b_t_u_d]);
      gradients[b_t_u_d] = math::max(math::min(unclamped, clamp), -clamp);
    }
  }
}
} // namespace rnnt
} // namespace torchaudio
#pragma once
#ifdef USE_CUDA
#include <cmath>
#endif // USE_CUDA
#include <torchaudio/csrc/rnnt/gpu/half.cuh>
namespace torchaudio {
namespace rnnt {
namespace math {
// Returns `x` when x > y, otherwise `y`.
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
  return (x > y) ? x : y;
}
// Returns `y` when x > y, otherwise `x`.
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
  return (x > y) ? y : x;
}
// log_sum_exp: returns log(exp(x) + exp(y)), evaluated relative to the larger
// argument so the exponential never overflows.
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);

template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
  // Two equal infinities: the difference below would be inf - inf = NaN and
  // poison the result, while the correct answer is that same infinity.
  // (x - y is exactly 0 for equal finite values, so they skip this branch.)
  if (x == y && (x - y) != 0.0f) {
    return x;
  }
  if (y > x) {
    return y + log1pf(expf(x - y));
  } else {
    return x + log1pf(expf(y - x));
  }
}
} // namespace math
} // namespace rnnt
} // namespace torchaudio
#include <torchaudio/csrc/rnnt/macros.h>
// Maps a level_t value to its upper-case name; any value that is not one of
// the four known levels yields "UNKNOWN".
const char* ToString(level_t level) {
  if (level == INFO) {
    return "INFO";
  }
  if (level == WARNING) {
    return "WARNING";
  }
  if (level == ERROR) {
    return "ERROR";
  }
  if (level == FATAL) {
    return "FATAL";
  }
  return "UNKNOWN";
}
#pragma once
#ifdef USE_CUDA
#define WARP_SIZE 32
#define MAX_THREADS_PER_BLOCK 1024
#define REDUCE_THREADS 256
#define HOST_AND_DEVICE __host__ __device__
#define FORCE_INLINE __forceinline__
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#else
#define HOST_AND_DEVICE
#define FORCE_INLINE inline
#endif // USE_CUDA
#include <cstring>
#include <iostream>
// Level tags with fixed numeric values (INFO = 0 ... FATAL = 3).
typedef enum { INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 } level_t;
// Returns a printable name for `level` ("INFO", "WARNING", "ERROR", "FATAL",
// or "UNKNOWN" for any other value).
const char* ToString(level_t level);
#pragma once
//#include <iostream>
#ifdef USE_CUDA
#include <cuda_runtime.h>
#endif // USE_CUDA
#include <torchaudio/csrc/rnnt/macros.h>
#include <torchaudio/csrc/rnnt/types.h>
namespace torchaudio {
namespace rnnt {
// Configuration shared by the transducer-loss routines: target device,
// CUDA stream (GPU builds only), blank index, clamp value and the padded
// problem dimensions.
typedef struct Options {
  // the device to compute transducer loss.
  device_t device_;
#ifdef USE_CUDA
  // the stream to launch kernels in when using GPU.
  cudaStream_t stream_;
#endif
  // The maximum number of threads that can be used.
  int numThreads_;
  // the index for "blank".
  int blank_;
  // whether to backtrack the best path.
  bool backtrack_;
  // gradient clamp value.
  float clamp_;
  // batch size = B.
  int batchSize_;
  // Number of hypos per sample = H
  int nHypos_;
  // the maximum length of src encodings = max_T.
  int maxSrcLen_;
  // the maximum length of tgt encodings = max_U.
  int maxTgtLen_;
  // num_targets = D.
  int numTargets_;

  // Every member gets a defined value. `stream_` used to be left
  // uninitialized, so reading it from a default-constructed Options was
  // undefined behavior; it now starts as the null stream handle.
  Options()
      : device_(UNDEFINED),
#ifdef USE_CUDA
        stream_(nullptr),
#endif
        numThreads_(0),
        blank_(-1),
        backtrack_(false),
        clamp_(-1), // negative for disabling clamping by default.
        batchSize_(0),
        nHypos_(1),
        maxSrcLen_(0),
        maxTgtLen_(0),
        numTargets_(0) {
  }

  // Returns batchSize_ * maxTgtLen_ * nHypos_.
  int BU() const {
    return batchSize_ * maxTgtLen_ * nHypos_;
  }

  // Returns batchSize_ * maxSrcLen_ * maxTgtLen_ * nHypos_.
  int BTU() const {
    return batchSize_ * maxSrcLen_ * maxTgtLen_ * nHypos_;
  }

  // Streams the batch size, the two maximum lengths and the target count.
  friend std::ostream& operator<<(std::ostream& os, const Options& options) {
    os << "Options("
       << "batchSize_=" << options.batchSize_ << ", "
       << "maxSrcLen_=" << options.maxSrcLen_ << ", "
       << "maxTgtLen_=" << options.maxTgtLen_ << ", "
       << "numTargets_=" << options.numTargets_ << ")";
    return os;
  }
} Options;
} // namespace rnnt
} // namespace torchaudio
#include <torchaudio/csrc/rnnt/types.h>
namespace torchaudio {
namespace rnnt {
// Returns the lower-case name of `status`, or "unknown" for a value that is
// not one of the listed status codes.
const char* toString(status_t status) {
  if (status == SUCCESS) {
    return "success";
  }
  if (status == FAILURE) {
    return "failure";
  }
  if (status == COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED) {
    return "compute_denominator_reduce_max_failed";
  }
  if (status == COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED) {
    return "compute_denominator_reduce_sum_failed";
  }
  if (status == COMPUTE_LOG_PROBS_FAILED) {
    return "compute_log_probs_failed";
  }
  if (status == COMPUTE_ALPHAS_BETAS_COSTS_FAILED) {
    return "compute_alphas_betas_costs_failed";
  }
  if (status == COMPUTE_GRADIENTS_FAILED) {
    return "compute_gradients_failed";
  }
  return "unknown";
}
// Returns the lower-case name of `device`, or "unknown" for a value that is
// not one of the listed devices.
const char* toString(device_t device) {
  if (device == UNDEFINED) {
    return "undefined";
  }
  if (device == CPU) {
    return "cpu";
  }
  if (device == GPU) {
    return "gpu";
  }
  return "unknown";
}
} // namespace rnnt
} // namespace torchaudio
#pragma once
namespace torchaudio {
namespace rnnt {
// Status codes with fixed numeric values; each code other than
// SUCCESS/FAILURE names the computation step that failed.
typedef enum {
SUCCESS = 0,
FAILURE = 1,
COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED = 2,
COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED = 3,
COMPUTE_LOG_PROBS_FAILED = 4,
COMPUTE_ALPHAS_BETAS_COSTS_FAILED = 5,
COMPUTE_GRADIENTS_FAILED = 6
} status_t;
// Device selector; UNDEFINED (0) is the value a default-constructed Options
// carries.
typedef enum { UNDEFINED = 0, CPU = 1, GPU = 2 } device_t;
// Return a lower-case printable name for the given enum value
// ("unknown" for unlisted values).
const char* toString(status_t status);
const char* toString(device_t device);
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <cstring>
#include <vector>
#include <torchaudio/csrc/rnnt/options.h>
namespace torchaudio {
namespace rnnt {
// Since CUDA has strict memory alignment, it's better to keep allocated memory
// blocks separate for different data types.
// DtypeWorkspace holds a "view" of workspace for:
// 1. softmax denominators (in log form), size = B * max_T * max_U
// 2. log probability pairs for blank and target, size = B * max_T * max_U * 2
// 3. alphas, size = B * max_T * max_U
// 4. betas, size = B * max_T * max_U
// Non-owning view over one caller-provided DTYPE buffer, laid out as
// [denominators | log-probs | alphas | betas].
template <typename DTYPE>
class DtypeWorkspace {
 public:
  DtypeWorkspace() : options_(), size_(0), data_(nullptr) {}

  DtypeWorkspace(const Options& options, DTYPE* data, int size)
      : DtypeWorkspace() {
    Reset(options, data, size);
  }

  ~DtypeWorkspace() {}

  // Total number of DTYPE elements the four sections need for `options`.
  // The device must have been chosen already.
  static int ComputeSizeFromOptions(const Options& options) {
    CHECK_NE(options.device_, UNDEFINED);
    int total = ComputeSizeForDenominators(options);
    total += ComputeSizeForLogProbs(options);
    total += ComputeSizeForAlphas(options);
    total += ComputeSizeForBetas(options);
    return total;
  }

  void Free();

  // Re-points the view at `data` (holding `size` elements) after checking
  // that the buffer is large enough for `options`.
  void Reset(const Options& options, DTYPE* data, int size) {
    const int required = ComputeSizeFromOptions(options);
    CHECK_LE(required, size);
    options_ = options;
    data_ = data;
    size_ = size;
  }

  // Number of elements in the underlying buffer.
  int Size() const {
    return size_;
  }

  // Section accessors; each section starts where the previous one ends.
  DTYPE* GetPointerToDenominators() const {
    return data_;
  }

  DTYPE* GetPointerToLogProbs() const {
    DTYPE* const base = GetPointerToDenominators();
    return base + ComputeSizeForDenominators(options_);
  }

  DTYPE* GetPointerToAlphas() const {
    DTYPE* const base = GetPointerToLogProbs();
    return base + ComputeSizeForLogProbs(options_);
  }

  DTYPE* GetPointerToBetas() const {
    DTYPE* const base = GetPointerToAlphas();
    return base + ComputeSizeForAlphas(options_);
  }

 private:
  // B * T * U
  static int ComputeSizeForDenominators(const Options& options) {
    return options.BTU();
  }

  // B * T * U * 2
  static int ComputeSizeForLogProbs(const Options& options) {
    return options.BTU() * 2;
  }

  // B * T * U
  static int ComputeSizeForAlphas(const Options& options) {
    return options.BTU();
  }

  // B * T * U
  static int ComputeSizeForBetas(const Options& options) {
    return options.BTU();
  }

  Options options_;
  int size_; // number of elements in allocated memory.
  DTYPE* data_; // pointer to the allocated memory.
};
// IntWorkspace holds a "view" of workspace for:
// 1. alpha counters, size = B * max_U
// 2. beta counters, size = B * max_U
// Non-owning view over one caller-provided int buffer, laid out as
// [alpha counters | beta counters]. Both sections are empty unless the
// options select the GPU device in a USE_CUDA build.
class IntWorkspace {
 public:
  IntWorkspace() : options_(), size_(0), data_(nullptr) {}

  IntWorkspace(const Options& options, int* data, int size) : IntWorkspace() {
    Reset(options, data, size);
  }

  ~IntWorkspace() {}

  // Total number of ints both counter sections need for `options`.
  static int ComputeSizeFromOptions(const Options& options) {
    return ComputeSizeForAlphaCounters(options) +
        ComputeSizeForBetaCounters(options);
  }

  // Re-points the view at `data` (holding `size` ints) after checking the
  // buffer is large enough, then zeroes the counters.
  void Reset(const Options& options, int* data, int size) {
    int needed_size = ComputeSizeFromOptions(options);
    CHECK_LE(needed_size, size);
    options_ = options;
    data_ = data;
    size_ = size;
    ResetAlphaBetaCounters();
  }

  // Number of ints in the underlying buffer.
  int Size() const {
    return size_;
  }

  // Counter accessors; only valid when the device is GPU.
  int* GetPointerToAlphaCounters() const {
    CHECK_EQ(options_.device_, GPU);
    return data_;
  }

  int* GetPointerToBetaCounters() const {
    CHECK_EQ(options_.device_, GPU);
    return GetPointerToAlphaCounters() + ComputeSizeForAlphaCounters(options_);
  }

 private:
  // Zero-fills both counter sections on the device. A failed memset is
  // reported through the same CHECK mechanism Reset() already uses instead of
  // being silently dropped, since stale counters would go unnoticed.
  inline void ResetAlphaBetaCounters() {
#ifdef USE_CUDA
    if (data_ != nullptr && options_.device_ == GPU) {
      const cudaError_t alphaStatus = cudaMemset(
          GetPointerToAlphaCounters(),
          0,
          ComputeSizeForAlphaCounters(options_) * sizeof(int));
      CHECK_EQ(static_cast<int>(alphaStatus), static_cast<int>(cudaSuccess));

      const cudaError_t betaStatus = cudaMemset(
          GetPointerToBetaCounters(),
          0,
          ComputeSizeForBetaCounters(options_) * sizeof(int));
      CHECK_EQ(static_cast<int>(betaStatus), static_cast<int>(cudaSuccess));
    }
#endif // USE_CUDA
  }

  // B * U on GPU, 0 otherwise.
  static int ComputeSizeForAlphaCounters(const Options& options) {
#ifdef USE_CUDA
    if (options.device_ == GPU) {
      return options.BU();
    } else {
      return 0;
    }
#else
    return 0;
#endif // USE_CUDA
  }

  // B * U on GPU, 0 otherwise.
  static int ComputeSizeForBetaCounters(const Options& options) {
#ifdef USE_CUDA
    if (options.device_ == GPU) {
      return options.BU();
    } else {
      return 0;
    }
#else
    return 0;
#endif // USE_CUDA
  }

  Options options_;
  int size_; // number of elements in allocated memory.
  int* data_; // pointer to the allocated memory.
};
// Workspace<DTYPE> holds:
// 1. DtypeWorkspace<DTYPE>
// 2. IntWorkspace
template <typename DTYPE>
class Workspace {
public:
Workspace() : options_(), dtype_workspace_(), int_workspace_() {}
Workspace(
const Options& options,
DTYPE* dtype_data,
int dtype_size,
int* int_data,
int int_size)
: Workspace() {
Reset(options, dtype_data, dtype_size, int_data, int_size);
}
~Workspace() {}
void Reset(
const Options& options,
DTYPE* dtype_data,
int dtype_size,
int* int_data,
int int_size) {
options_ = options;
dtype_workspace_.Reset(options_, dtype_data, dtype_size);
int_workspace_.Reset(options_, int_data, int_size);
}
const Options& GetOptions() const {
return options_;
}
DTYPE* GetPointerToDenominators() const {
return dtype_workspace_.GetPointerToDenominators();
}
DTYPE* GetPointerToLogProbs() const {
return dtype_workspace_.GetPointerToLogProbs();
}
DTYPE* GetPointerToAlphas() const {
return dtype_workspace_.GetPointerToAlphas();
}
DTYPE* GetPointerToBetas() const {
return dtype_workspace_.GetPointerToBetas();
}
int* GetPointerToAlphaCounters() const {
return int_workspace_.GetPointerToAlphaCounters();
}
int* GetPointerToBetaCounters() const {
return int_workspace_.GetPointerToBetaCounters();
}
private:
Options options_;
DtypeWorkspace<DTYPE> dtype_workspace_;
IntWorkspace int_workspace_;
};
} // namespace rnnt
} // namespace torchaudio
#include <sox.h>
#include <torchaudio/csrc/sox/effects.h>
#include <torchaudio/csrc/sox/effects_chain.h>
#include <torchaudio/csrc/sox/utils.h>
using namespace torchaudio::sox_utils;
namespace torchaudio::sox_effects {
// File-local state tracking whether sox_init()/sox_quit() have been called.
namespace {
// Lifecycle of the SoX effects resource; the functions below only ever move
// it NotInitialized -> Initialized -> ShutDown.
enum SoxEffectsResourceState { NotInitialized, Initialized, ShutDown };
SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized;
// Guards every read and write of SOX_RESOURCE_STATE.
std::mutex SOX_RESOUCE_STATE_MUTEX;
} // namespace
// Calls sox_init() the first time it is invoked; later calls are no-ops.
// Throws std::runtime_error when sox_init() fails or when the resource has
// already been shut down.
void initialize_sox_effects() {
  const std::lock_guard<std::mutex> lock(SOX_RESOUCE_STATE_MUTEX);

  if (SOX_RESOURCE_STATE == ShutDown) {
    throw std::runtime_error(
        "SoX Effects has been shut down. Cannot initialize again.");
  }

  if (SOX_RESOURCE_STATE == NotInitialized) {
    const bool started = (sox_init() == SOX_SUCCESS);
    if (!started) {
      throw std::runtime_error("Failed to initialize sox effects.");
    }
    SOX_RESOURCE_STATE = Initialized;
  }
  // Already Initialized: nothing to do.
}
// Calls sox_quit() once after a successful initialization; calling it again
// after shutdown is a no-op. Throws std::runtime_error when SoX effects were
// never initialized or when sox_quit() fails.
void shutdown_sox_effects() {
  const std::lock_guard<std::mutex> lock(SOX_RESOUCE_STATE_MUTEX);
  switch (SOX_RESOURCE_STATE) {
    case NotInitialized:
      throw std::runtime_error(
          "SoX Effects is not initialized. Cannot shutdown.");
    case Initialized:
      if (sox_quit() != SOX_SUCCESS) {
        // The message previously claimed an initialization failure, which
        // pointed anyone debugging at the wrong call.
        throw std::runtime_error("Failed to shut down sox effects.");
      }
      SOX_RESOURCE_STATE = ShutDown;
      break;
    case ShutDown:
      break;
  }
}
// Runs `effects` over `waveform` through a SoX effects chain and returns the
// processed tensor (same dtype as the input, not normalized) together with
// the chain's output sample rate.
auto apply_effects_tensor(
    torch::Tensor waveform,
    int64_t sample_rate,
    const std::vector<std::vector<std::string>>& effects,
    bool channels_first) -> std::tuple<torch::Tensor, int64_t> {
  validate_input_tensor(waveform);

  // Input and output of the chain use the encoding implied by the dtype.
  const auto dtype = waveform.dtype();
  torchaudio::sox_effects_chain::SoxEffectsChain chain(
      /*input_encoding=*/get_tensor_encodinginfo(dtype),
      /*output_encoding=*/get_tensor_encodinginfo(dtype));

  // Samples produced by the chain are collected here.
  std::vector<sox_sample_t> samples;
  samples.reserve(waveform.numel());

  // Assemble: tensor source -> requested effects -> buffer sink, then run.
  chain.addInputTensor(&waveform, sample_rate, channels_first);
  for (size_t i = 0; i < effects.size(); ++i) {
    chain.addEffect(effects[i]);
  }
  chain.addOutputBuffer(&samples);
  chain.run();

  // Wrap the collected samples in a tensor.
  auto result = convert_to_tensor(
      /*buffer=*/samples.data(),
      /*num_samples=*/samples.size(),
      /*num_channels=*/chain.getOutputNumChannels(),
      dtype,
      /*normalize=*/false,
      channels_first);

  return std::make_tuple(result, chain.getOutputSampleRate());
}
// Opens the audio file at `path`, runs `effects` over it through a SoX
// effects chain and returns the resulting tensor with the chain's output
// sample rate. `normalize` and `channels_first` both default to true;
// `format` optionally overrides the file type passed to sox_open_read.
auto apply_effects_file(
    const std::string& path,
    const std::vector<std::vector<std::string>>& effects,
    c10::optional<bool> normalize,
    c10::optional<bool> channels_first,
    const c10::optional<std::string>& format)
    -> std::tuple<torch::Tensor, int64_t> {
  // Open the input file.
  const char* filetype = format.has_value() ? format.value().c_str() : nullptr;
  SoxFormat sf(sox_open_read(
      path.c_str(),
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/filetype));
  validate_input_file(sf, path);

  const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);

  // Samples produced by the chain are collected here.
  std::vector<sox_sample_t> samples;
  samples.reserve(sf->signal.length);

  // Assemble: file source -> requested effects -> buffer sink, then run.
  torchaudio::sox_effects_chain::SoxEffectsChain chain(
      /*input_encoding=*/sf->encoding,
      /*output_encoding=*/get_tensor_encodinginfo(dtype));
  chain.addInputFile(sf);
  for (size_t i = 0; i < effects.size(); ++i) {
    chain.addEffect(effects[i]);
  }
  chain.addOutputBuffer(&samples);
  chain.run();

  // Wrap the collected samples in a tensor.
  const bool want_channels_first = channels_first.value_or(true);
  const bool want_normalize = normalize.value_or(true);
  auto result = convert_to_tensor(
      /*buffer=*/samples.data(),
      /*num_samples=*/samples.size(),
      /*num_channels=*/chain.getOutputNumChannels(),
      dtype,
      want_normalize,
      want_channels_first);

  return std::make_tuple(result, chain.getOutputSampleRate());
}
// Registers the four functions defined in this file as operators in the
// `torchaudio` library namespace under `sox_effects_*` names.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::sox_effects_initialize_sox_effects",
&torchaudio::sox_effects::initialize_sox_effects);
m.def(
"torchaudio::sox_effects_shutdown_sox_effects",
&torchaudio::sox_effects::shutdown_sox_effects);
m.def(
"torchaudio::sox_effects_apply_effects_tensor",
&torchaudio::sox_effects::apply_effects_tensor);
m.def(
"torchaudio::sox_effects_apply_effects_file",
&torchaudio::sox_effects::apply_effects_file);
}
} // namespace torchaudio::sox_effects
#ifndef TORCHAUDIO_SOX_EFFECTS_H
#define TORCHAUDIO_SOX_EFFECTS_H
#include <torch/script.h>
#include <torchaudio/csrc/sox/utils.h>
namespace torchaudio::sox_effects {
// Initializes the SoX effects resource; throws std::runtime_error on failure
// or if the resource was already shut down.
void initialize_sox_effects();
// Shuts the SoX effects resource down; throws std::runtime_error if it was
// never initialized or if shutting down fails.
void shutdown_sox_effects();
// Applies `effects` (each inner vector is an effect name followed by its
// arguments) to `waveform`; returns the processed tensor and the output
// sample rate.
auto apply_effects_tensor(
torch::Tensor waveform,
int64_t sample_rate,
const std::vector<std::vector<std::string>>& effects,
bool channels_first) -> std::tuple<torch::Tensor, int64_t>;
// Loads the file at `path`, applies `effects` and returns the resulting
// tensor and the output sample rate. `normalize` and `channels_first` are
// treated as true when unset; `format` optionally names the file type.
auto apply_effects_file(
const std::string& path,
const std::vector<std::vector<std::string>>& effects,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
const c10::optional<std::string>& format)
-> std::tuple<torch::Tensor, int64_t>;
} // namespace torchaudio::sox_effects
#endif
#include <torchaudio/csrc/sox/effects_chain.h>
#include <torchaudio/csrc/sox/utils.h>
using namespace torch::indexing;
using namespace torchaudio::sox_utils;
namespace torchaudio {
namespace sox_effects_chain {
namespace {
/// helper classes for passing the location of input tensor and output buffer
///
/// drain/flow callback functions require a plain C-style function signature, and
/// the way to pass extra data is to attach data to sox_effect_t::priv pointer.
/// The following structs will be assigned to sox_effect_t::priv pointer which
/// gives sox_effect_t an access to input Tensor and output buffer object.
// Private data of the tensor-input effect.
struct TensorInputPriv {
// Number of samples already handed to the chain; advanced by the drain
// callback.
size_t index;
// Tensor being read (not owned).
torch::Tensor* waveform;
int64_t sample_rate;
// Whether `waveform` is laid out channels-first; selects how the drain
// callback slices it.
bool channels_first;
};
// Private data of the tensor-output effect.
struct TensorOutputPriv {
// Destination the flow callback appends samples to (not owned).
std::vector<sox_sample_t>* buffer;
};
// Private data of the file-output effect.
struct FileOutputPriv {
// File handle the flow callback writes to with sox_write (not owned).
sox_format_t* sf;
};
/// Callback function to feed Tensor data to SoxEffectChain.
/// Drain callback of the tensor-input effect: copies the next chunk of the
/// input Tensor into `obuf` as sox_sample_t values.
///
/// On entry `*osamp` is the number of samples requested; on exit it is the
/// number actually written (never past the end of the tensor and always a
/// multiple of the channel count). Returns SOX_EOF once the whole tensor has
/// been consumed, SOX_SUCCESS otherwise. Throws std::runtime_error for an
/// unsupported dtype.
int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
  // Recover the effect state attached to the effect.
  auto state = static_cast<TensorInputPriv*>(effp->priv);
  auto offset = state->index;
  auto waveform = *(state->waveform);
  auto channels = effp->out_signal.channels;

  // Never read past the end of the tensor ...
  const size_t total_samples = waveform.numel();
  if (offset + *osamp > total_samples) {
    *osamp = total_samples - offset;
  }
  // ... and only hand over whole frames.
  *osamp -= *osamp % channels;

  // Slice the requested frames and flatten them in interleaved order.
  auto first_frame = offset / channels;
  auto frame_count = *osamp / channels;
  torch::Tensor chunk;
  if (state->channels_first) {
    chunk =
        waveform.index({Slice(), Slice(first_frame, first_frame + frame_count)})
            .t()
            .reshape({-1});
  } else {
    chunk =
        waveform.index({Slice(first_frame, first_frame + frame_count), Slice()})
            .reshape({-1});
  }

  // Convert to sox_sample_t (int32_t).
  const auto scalar_type = chunk.dtype().toScalarType();
  if (scalar_type == c10::ScalarType::Float) {
    // Need to convert to 64-bit precision so that
    // values around INT32_MIN/MAX are handled correctly.
    chunk = chunk.to(c10::ScalarType::Double);
    chunk *= 2147483648.;
    chunk.clamp_(INT32_MIN, INT32_MAX);
    chunk = chunk.to(c10::ScalarType::Int);
  } else if (scalar_type == c10::ScalarType::Int) {
    // Already in the target representation.
  } else if (scalar_type == c10::ScalarType::Short) {
    chunk = chunk.to(c10::ScalarType::Int);
    chunk *= 65536;
  } else if (scalar_type == c10::ScalarType::Byte) {
    chunk = chunk.to(c10::ScalarType::Int);
    chunk -= 128;
    chunk *= 16777216;
  } else {
    throw std::runtime_error("Unexpected dtype.");
  }

  // Write to the output buffer and advance the read position.
  chunk = chunk.contiguous();
  memcpy(obuf, chunk.data_ptr<int32_t>(), *osamp * 4);
  state->index += *osamp;

  if (state->index == total_samples) {
    return SOX_EOF;
  }
  return SOX_SUCCESS;
}
/// Callback function to fetch data from SoxEffectChain.
/// Flow callback of the tensor-output effect: appends the `*isamp` incoming
/// samples to the attached buffer. Produces no output samples (`*osamp` is
/// set to 0) and always returns SOX_SUCCESS.
int tensor_output_flow(
    sox_effect_t* effp,
    sox_sample_t const* ibuf,
    sox_sample_t* obuf LSX_UNUSED,
    size_t* isamp,
    size_t* osamp) {
  *osamp = 0;

  // Destination buffer attached to this effect.
  auto sink = static_cast<TensorOutputPriv*>(effp->priv)->buffer;

  // Append [ibuf, ibuf + *isamp) at the end.
  sox_sample_t const* const first = ibuf;
  sox_sample_t const* const last = ibuf + *isamp;
  sink->insert(sink->end(), first, last);

  return SOX_SUCCESS;
}
/// Flow callback of the file-output effect: writes the `*isamp` incoming
/// samples to the attached file with sox_write. Produces no output samples.
/// Returns SOX_SUCCESS when everything was written (or there was nothing to
/// write); on a short write it throws std::runtime_error if the file reports
/// an error code and returns SOX_EOF otherwise.
int file_output_flow(
    sox_effect_t* effp,
    sox_sample_t const* ibuf,
    sox_sample_t* obuf LSX_UNUSED,
    size_t* isamp,
    size_t* osamp) {
  *osamp = 0;

  if (!*isamp) {
    return SOX_SUCCESS;
  }

  auto sf = static_cast<FileOutputPriv*>(effp->priv)->sf;
  if (sox_write(sf, ibuf, *isamp) == *isamp) {
    return SOX_SUCCESS;
  }

  // Short write: surface the file's error if it recorded one.
  if (sf->sox_errno) {
    std::ostringstream message;
    message << sf->sox_errstr;
    message << " " << sox_strerror(sf->sox_errno);
    message << " " << sf->filename;
    throw std::runtime_error(message.str());
  }
  return SOX_EOF;
}
// Returns the (function-local static) handler describing the tensor-input
// effect; only the drain callback is provided.
sox_effect_handler_t* get_tensor_input_handler() {
  static sox_effect_handler_t handler = {
      "input_tensor", // name
      nullptr, // usage
      SOX_EFF_MCHAN, // flags
      nullptr, // getopts
      nullptr, // start
      nullptr, // flow
      tensor_input_drain, // drain
      nullptr, // stop
      nullptr, // kill
      sizeof(TensorInputPriv), // priv_size
  };
  return &handler;
}
// Returns the (function-local static) handler describing the tensor-output
// effect; only the flow callback is provided.
sox_effect_handler_t* get_tensor_output_handler() {
  static sox_effect_handler_t handler = {
      "output_tensor", // name
      nullptr, // usage
      SOX_EFF_MCHAN, // flags
      nullptr, // getopts
      nullptr, // start
      tensor_output_flow, // flow
      nullptr, // drain
      nullptr, // stop
      nullptr, // kill
      sizeof(TensorOutputPriv), // priv_size
  };
  return &handler;
}
// Returns the (function-local static) handler describing the file-output
// effect; only the flow callback is provided.
sox_effect_handler_t* get_file_output_handler() {
  static sox_effect_handler_t handler = {
      "output_file", // name
      nullptr, // usage
      SOX_EFF_MCHAN, // flags
      nullptr, // getopts
      nullptr, // start
      file_output_flow, // flow
      nullptr, // drain
      nullptr, // stop
      nullptr, // kill
      sizeof(FileOutputPriv), // priv_size
  };
  return &handler;
}
} // namespace
// Takes over the effect pointer; it is released with free() in the destructor.
SoxEffect::SoxEffect(sox_effect_t* se) noexcept : se_(se) {}

// Releases the wrapped pointer, if any.
SoxEffect::~SoxEffect() {
  if (se_ == nullptr) {
    return;
  }
  free(se_);
}

// Implicit access to the raw pointer, so the wrapper can be passed straight
// to the sox C API.
SoxEffect::operator sox_effect_t*() const {
  sox_effect_t* const raw = se_;
  return raw;
}

// Member access on the wrapped effect.
auto SoxEffect::operator->() noexcept -> sox_effect_t* {
  sox_effect_t* const raw = se_;
  return raw;
}
// Stores the two encodings, value-initializes the signal infos and creates
// the underlying sox effects chain. Throws std::runtime_error when the chain
// cannot be created.
SoxEffectsChain::SoxEffectsChain(
    sox_encodinginfo_t input_encoding,
    sox_encodinginfo_t output_encoding)
    : in_enc_(input_encoding),
      out_enc_(output_encoding),
      in_sig_(),
      interm_sig_(),
      out_sig_(),
      sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) {
  if (sec_ == nullptr) {
    throw std::runtime_error("Failed to create effect chain.");
  }
}

// Deletes the underlying chain, if one was created.
SoxEffectsChain::~SoxEffectsChain() {
  if (sec_ == nullptr) {
    return;
  }
  sox_delete_effects_chain(sec_);
}

// Flows data through every effect in the chain; no callback is installed.
void SoxEffectsChain::run() {
  sox_flow_effects(sec_, nullptr, nullptr);
}
// Adds the tensor-input effect as the source of the chain. The signal info
// is derived from `waveform`; the tensor is referenced, not copied, so it
// must stay alive until the chain has run. Throws std::runtime_error when
// the effect cannot be added.
void SoxEffectsChain::addInputTensor(
    torch::Tensor* waveform,
    int64_t sample_rate,
    bool channels_first) {
  in_sig_ = get_signalinfo(waveform, sample_rate, "wav", channels_first);
  interm_sig_ = in_sig_;

  SoxEffect e(sox_create_effect(get_tensor_input_handler()));

  // Fill in the effect's private state; reading starts at sample 0.
  TensorInputPriv* const state = static_cast<TensorInputPriv*>(e->priv);
  state->channels_first = channels_first;
  state->sample_rate = sample_rate;
  state->waveform = waveform;
  state->index = 0;

  const bool added =
      sox_add_effect(sec_, e, &interm_sig_, &in_sig_) == SOX_SUCCESS;
  if (!added) {
    throw std::runtime_error(
        "Internal Error: Failed to add effect: input_tensor");
  }
}
void SoxEffectsChain::addOutputBuffer(
std::vector<sox_sample_t>* output_buffer) {
SoxEffect e(sox_create_effect(get_tensor_output_handler()));
static_cast<TensorOutputPriv*>(e->priv)->buffer = output_buffer;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error(
"Internal Error: Failed to add effect: output_tensor");
}
}
// Adds SoX's built-in "input" effect, reading from `sf`, as the source of
// the chain. The chain's input signal info is taken from the file. Throws
// std::runtime_error when the effect cannot be configured or added.
void SoxEffectsChain::addInputFile(sox_format_t* sf) {
  in_sig_ = sf->signal;
  interm_sig_ = in_sig_;
  SoxEffect e(sox_create_effect(sox_find_effect("input")));
  // The "input" effect receives the file handle disguised as its single
  // string option.
  char* opts[] = {(char*)sf};
  // The result of sox_effect_options used to be ignored; a rejected handle
  // would then only show up later as a less specific failure.
  if (sox_effect_options(e, 1, opts) != SOX_SUCCESS) {
    std::ostringstream stream;
    stream << "Internal Error: Failed to configure effect: input "
           << sf->filename;
    throw std::runtime_error(stream.str());
  }
  if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
    std::ostringstream stream;
    stream << "Internal Error: Failed to add effect: input " << sf->filename;
    throw std::runtime_error(stream.str());
  }
}
// Adds the file-output effect, writing to `sf`, as the sink of the chain.
// The chain's output signal info is taken from the file. Throws
// std::runtime_error when the effect cannot be added.
void SoxEffectsChain::addOutputFile(sox_format_t* sf) {
  out_sig_ = sf->signal;

  SoxEffect e(sox_create_effect(get_file_output_handler()));
  FileOutputPriv* const state = static_cast<FileOutputPriv*>(e->priv);
  state->sf = sf;

  const bool added =
      sox_add_effect(sec_, e, &interm_sig_, &out_sig_) == SOX_SUCCESS;
  if (added) {
    return;
  }

  std::ostringstream message;
  message << "Internal Error: Failed to add effect: output " << sf->filename;
  throw std::runtime_error(message.str());
}
// Appends one user-specified effect to the chain. `effect[0]` is the effect
// name and the remaining elements are its options. Throws std::runtime_error
// when the list is empty, the effect is unsupported or unknown to sox, its
// options are rejected, or it cannot be added to the chain.
void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
  const auto arg_count = effect.size();
  if (arg_count == 0) {
    throw std::runtime_error("Invalid argument: empty effect.");
  }
  const auto name = effect[0];

  // Shared by the two "unsupported" rejections below.
  const auto reject_unsupported = [&name]() {
    std::ostringstream message;
    message << "Unsupported effect: " << name;
    throw std::runtime_error(message.str());
  };

  if (UNSUPPORTED_EFFECTS.find(name) != UNSUPPORTED_EFFECTS.end()) {
    reject_unsupported();
  }
  auto handler = sox_find_effect(name.c_str());
  if (!handler) {
    reject_unsupported();
  }

  SoxEffect e(sox_create_effect(handler));

  // Collect the options (everything after the name) as C strings.
  const auto option_count = arg_count - 1;
  std::vector<char*> option_ptrs;
  for (auto it = effect.begin() + 1; it != effect.end(); ++it) {
    option_ptrs.push_back((char*)it->c_str());
  }
  char** const option_argv = option_count ? option_ptrs.data() : nullptr;

  if (sox_effect_options(e, option_count, option_argv) != SOX_SUCCESS) {
    std::ostringstream message;
    message << "Invalid effect option:";
    for (size_t i = 0; i < arg_count; ++i) {
      message << " " << effect[i];
    }
    throw std::runtime_error(message.str());
  }

  if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
    std::ostringstream message;
    message << "Internal Error: Failed to add effect: \"" << name;
    for (size_t i = 1; i < arg_count; ++i) {
      message << " " << effect[i];
    }
    message << "\"";
    throw std::runtime_error(message.str());
  }
}
// Channel count recorded in the intermediate signal info.
int64_t SoxEffectsChain::getOutputNumChannels() {
  const int64_t channels = interm_sig_.channels;
  return channels;
}

// Sample rate recorded in the intermediate signal info, converted to an
// integer.
int64_t SoxEffectsChain::getOutputSampleRate() {
  const int64_t rate = interm_sig_.rate;
  return rate;
}
} // namespace sox_effects_chain
} // namespace torchaudio
#ifndef TORCHAUDIO_SOX_EFFECTS_CHAIN_H
#define TORCHAUDIO_SOX_EFFECTS_CHAIN_H
#include <sox.h>
#include <torchaudio/csrc/sox/utils.h>
namespace torchaudio {
namespace sox_effects_chain {
// Helper struct to safely close sox_effect_t* pointer returned by
// sox_create_effect. The wrapper is neither copyable nor movable; the
// destructor (defined in the source file) releases the pointer.
struct SoxEffect {
explicit SoxEffect(sox_effect_t* se) noexcept;
SoxEffect(const SoxEffect& other) = delete;
SoxEffect(const SoxEffect&& other) = delete;
auto operator=(const SoxEffect& other) -> SoxEffect& = delete;
auto operator=(SoxEffect&& other) -> SoxEffect& = delete;
~SoxEffect();
// Implicit conversion to the raw pointer, for passing to the sox C API.
operator sox_effect_t*() const;
// Member access on the wrapped effect.
auto operator->() noexcept -> sox_effect_t*;
private:
// Wrapped effect pointer.
sox_effect_t* se_;
};
// Helper class that owns a sox_effects_chain_t (deleted in the destructor)
// and offers methods for adding a source, effects and a sink, then running
// the chain. Neither copyable nor movable.
class SoxEffectsChain {
// Encodings handed to sox_create_effects_chain; the chain is given pointers
// to these members.
const sox_encodinginfo_t in_enc_;
const sox_encodinginfo_t out_enc_;
protected:
// Signal info of the chain input, of the signal after the effects added so
// far, and of the file output.
sox_signalinfo_t in_sig_;
sox_signalinfo_t interm_sig_;
sox_signalinfo_t out_sig_;
// The underlying sox chain.
sox_effects_chain_t* sec_;
public:
explicit SoxEffectsChain(
sox_encodinginfo_t input_encoding,
sox_encodinginfo_t output_encoding);
SoxEffectsChain(const SoxEffectsChain& other) = delete;
SoxEffectsChain(const SoxEffectsChain&& other) = delete;
SoxEffectsChain& operator=(const SoxEffectsChain& other) = delete;
SoxEffectsChain& operator=(SoxEffectsChain&& other) = delete;
~SoxEffectsChain();
// Flows data through all effects added so far.
void run();
// Source effects: read from a tensor or from an opened file.
void addInputTensor(
torch::Tensor* waveform,
int64_t sample_rate,
bool channels_first);
void addInputFile(sox_format_t* sf);
// Sink effects: append to a sample buffer or write to an opened file.
void addOutputBuffer(std::vector<sox_sample_t>* output_buffer);
void addOutputFile(sox_format_t* sf);
// Adds a sox effect; effect[0] is its name, the rest are its options.
void addEffect(const std::vector<std::string> effect);
// Channel count / sample rate of the signal after the effects added so far.
int64_t getOutputNumChannels();
int64_t getOutputSampleRate();
};
} // namespace sox_effects_chain
} // namespace torchaudio
#endif
#include <torchaudio/csrc/sox/effects.h>
#include <torchaudio/csrc/sox/effects_chain.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/types.h>
#include <torchaudio/csrc/sox/utils.h>
using namespace torch::indexing;
using namespace torchaudio::sox_utils;
namespace torchaudio {
namespace sox_io {
// Opens the audio file at `path` and returns, in order: sample rate, number
// of frames (signal length divided by channel count), number of channels,
// bits per sample and the encoding name. `format` optionally overrides the
// file type passed to sox_open_read.
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
    const std::string& path,
    const c10::optional<std::string>& format) {
  const char* filetype = format.has_value() ? format.value().c_str() : nullptr;
  SoxFormat sf(sox_open_read(
      path.c_str(),
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/filetype));

  validate_input_file(sf, path);

  const auto sample_rate = static_cast<int64_t>(sf->signal.rate);
  const auto num_frames =
      static_cast<int64_t>(sf->signal.length / sf->signal.channels);
  const auto num_channels = static_cast<int64_t>(sf->signal.channels);
  const auto bits_per_sample =
      static_cast<int64_t>(sf->encoding.bits_per_sample);

  return std::make_tuple(
      sample_rate,
      num_frames,
      num_channels,
      bits_per_sample,
      get_encoding(sf->encoding.encoding));
}
// Builds the effect list that selects the requested frame window:
//   - num_frames given:            {"trim", "<offset>s", "+<frames>s"}
//   - only a non-zero frame_offset: {"trim", "<offset>s"}
//   - otherwise:                    no effects.
// frame_offset defaults to 0 and num_frames to -1 (meaning "all").
// Throws std::runtime_error for a negative offset, or for a frame count that
// is 0 or below -1.
std::vector<std::vector<std::string>> get_effects(
    const c10::optional<int64_t>& frame_offset,
    const c10::optional<int64_t>& num_frames) {
  const int64_t offset = frame_offset.value_or(0);
  if (offset < 0) {
    throw std::runtime_error(
        "Invalid argument: frame_offset must be non-negative.");
  }

  const int64_t frames = num_frames.value_or(-1);
  if (frames == 0 || frames < -1) {
    throw std::runtime_error(
        "Invalid argument: num_frames must be -1 or greater than 0.");
  }

  std::vector<std::vector<std::string>> effects;

  const bool has_frame_limit = (frames != -1);
  if (!has_frame_limit && offset == 0) {
    // Whole file requested: nothing to trim.
    return effects;
  }

  // Positions are expressed in samples, hence the "s" suffix.
  std::vector<std::string> trim;
  trim.push_back("trim");
  trim.push_back(std::to_string(offset) + "s");
  if (has_frame_limit) {
    trim.push_back("+" + std::to_string(frames) + "s");
  }
  effects.emplace_back(trim);
  return effects;
}
// Loads the audio file at `path`, restricted to the frame window described
// by `frame_offset` / `num_frames`, by delegating to apply_effects_file with
// the matching trim effect. Returns the tensor and the sample rate.
std::tuple<torch::Tensor, int64_t> load_audio_file(
    const std::string& path,
    const c10::optional<int64_t>& frame_offset,
    const c10::optional<int64_t>& num_frames,
    c10::optional<bool> normalize,
    c10::optional<bool> channels_first,
    const c10::optional<std::string>& format) {
  const std::vector<std::vector<std::string>> trim_effects =
      get_effects(frame_offset, num_frames);
  return torchaudio::sox_effects::apply_effects_file(
      path, trim_effects, normalize, channels_first, format);
}
void save_audio_file(
const std::string& path,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format,
c10::optional<std::string> encoding,
c10::optional<int64_t> bits_per_sample) {
validate_input_tensor(tensor);
const auto filetype = [&]() {
if (format.has_value())
return format.value();
return get_filetype(path);
}();
if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "amr-nb format only supports single channel audio.");
} else if (filetype == "htk") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "htk format only supports single channel audio.");
} else if (filetype == "gsm") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "gsm format only supports single channel audio.");
TORCH_CHECK(
sample_rate == 8000,
"gsm format only supports a sampling rate of 8kHz.");
}
const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = get_encodinginfo_for_save(
filetype, tensor.dtype(), compression, encoding, bits_per_sample);
SoxFormat sf(sox_open_write(
path.c_str(),
&signal_info,
&encoding_info,
/*filetype=*/filetype.c_str(),
/*oob=*/nullptr,
/*overwrite_permitted=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error(
"Error saving audio file: failed to open file " + path);
}
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
/*output_encoding=*/sf->encoding);
chain.addInputTensor(&tensor, sample_rate, channels_first);
chain.addOutputFile(sf);
chain.run();
}
// Register the sox_io entry points as operators in the `torchaudio` library
// fragment. The string is the operator name; the pointer is the C++
// implementation bound to it.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
// File metadata query.
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
// File -> Tensor.
m.def(
"torchaudio::sox_io_load_audio_file",
&torchaudio::sox_io::load_audio_file);
// Tensor -> file.
m.def(
"torchaudio::sox_io_save_audio_file",
&torchaudio::sox_io::save_audio_file);
}
} // namespace sox_io
} // namespace torchaudio
#ifndef TORCHAUDIO_SOX_IO_H
#define TORCHAUDIO_SOX_IO_H
#include <torch/script.h>
#include <torchaudio/csrc/sox/utils.h>
namespace torchaudio {
namespace sox_io {
/// Build the list of sox effects that trims the input to the requested
/// frame window. The result is either empty or a single "trim" effect.
/// Throws std::runtime_error for a negative frame_offset, or for a
/// num_frames that is 0 or smaller than -1.
auto get_effects(
const c10::optional<int64_t>& frame_offset,
const c10::optional<int64_t>& num_frames)
-> std::vector<std::vector<std::string>>;
/// Query metadata of an audio file.
/// The returned tuple is, in order: sample rate, number of frames
/// (signal length divided by channel count), number of channels,
/// bits per sample, and the encoding name.
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
const std::string& path,
const c10::optional<std::string>& format);
/// Load an audio file, optionally restricted to the window described by
/// frame_offset / num_frames (see get_effects).
/// NOTE(review): the second tuple element is presumably the sample rate;
/// confirm against sox_effects::apply_effects_file, which produces it.
std::tuple<torch::Tensor, int64_t> load_audio_file(
const std::string& path,
const c10::optional<int64_t>& frame_offset,
const c10::optional<int64_t>& num_frames,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
const c10::optional<std::string>& format);
/// Save a 2D Tensor as an audio file.
/// When `format` is not given, the file type is derived from `path`.
/// `compression`, `encoding` and `bits_per_sample` select the output
/// encoding; which combinations are accepted depends on the file type.
void save_audio_file(
const std::string& path,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format,
c10::optional<std::string> encoding,
c10::optional<int64_t> bits_per_sample);
} // namespace sox_io
} // namespace torchaudio
#endif
#include <torchaudio/csrc/sox/types.h>
namespace torchaudio {
namespace sox_utils {
/// Translate a file-type string into the corresponding Format value.
///
/// Both "ogg" and "vorbis" map to Format::VORBIS.
/// @throws std::runtime_error for any string that is not listed below.
Format get_format_from_string(const std::string& format) {
  struct Entry {
    const char* name;
    Format value;
  };
  static const Entry kKnownFormats[] = {
      {"wav", Format::WAV},
      {"mp3", Format::MP3},
      {"flac", Format::FLAC},
      {"ogg", Format::VORBIS},
      {"vorbis", Format::VORBIS},
      {"amr-nb", Format::AMR_NB},
      {"amr-wb", Format::AMR_WB},
      {"amb", Format::AMB},
      {"sph", Format::SPHERE},
      {"htk", Format::HTK},
      {"gsm", Format::GSM},
  };
  for (const Entry& entry : kKnownFormats) {
    if (format == entry.name) {
      return entry.value;
    }
  }
  std::ostringstream message;
  message << "Internal Error: unexpected format value: " << format;
  throw std::runtime_error(message.str());
}
/// Return the short textual name of an Encoding value.
///
/// @throws std::runtime_error for values without a name in the table below
/// (this includes Encoding::NOT_PROVIDED).
std::string to_string(Encoding v) {
  const char* label = nullptr;
  switch (v) {
    case Encoding::UNKNOWN:
      label = "UNKNOWN";
      break;
    case Encoding::PCM_SIGNED:
      label = "PCM_S";
      break;
    case Encoding::PCM_UNSIGNED:
      label = "PCM_U";
      break;
    case Encoding::PCM_FLOAT:
      label = "PCM_F";
      break;
    case Encoding::FLAC:
      label = "FLAC";
      break;
    case Encoding::ULAW:
      label = "ULAW";
      break;
    case Encoding::ALAW:
      label = "ALAW";
      break;
    case Encoding::MP3:
      label = "MP3";
      break;
    case Encoding::VORBIS:
      label = "VORBIS";
      break;
    case Encoding::AMR_WB:
      label = "AMR_WB";
      break;
    case Encoding::AMR_NB:
      label = "AMR_NB";
      break;
    case Encoding::OPUS:
      label = "OPUS";
      break;
    default:
      break;
  }
  if (label == nullptr) {
    throw std::runtime_error("Internal Error: unexpected encoding.");
  }
  return std::string(label);
}
/// Parse the optional user-facing `encoding` argument.
///
/// A missing value yields Encoding::NOT_PROVIDED. Accepted names are
/// "PCM_S", "PCM_U", "PCM_F", "ULAW" and "ALAW".
/// @throws std::runtime_error for any other name.
Encoding get_encoding_from_option(const c10::optional<std::string> encoding) {
  if (!encoding.has_value()) {
    return Encoding::NOT_PROVIDED;
  }
  const std::string& name = encoding.value();

  struct Entry {
    const char* name;
    Encoding value;
  };
  static const Entry kAccepted[] = {
      {"PCM_S", Encoding::PCM_SIGNED},
      {"PCM_U", Encoding::PCM_UNSIGNED},
      {"PCM_F", Encoding::PCM_FLOAT},
      {"ULAW", Encoding::ULAW},
      {"ALAW", Encoding::ALAW},
  };
  for (const Entry& entry : kAccepted) {
    if (name == entry.name) {
      return entry.value;
    }
  }

  std::ostringstream message;
  message << "Internal Error: unexpected encoding value: " << name;
  throw std::runtime_error(message.str());
}
/// Parse the optional user-facing `bits_per_sample` argument.
///
/// A missing value yields BitDepth::NOT_PROVIDED. Accepted values are
/// 8, 16, 24, 32 and 64.
/// @throws std::runtime_error for any other value.
BitDepth get_bit_depth_from_option(const c10::optional<int64_t> bit_depth) {
  if (!bit_depth.has_value()) {
    return BitDepth::NOT_PROVIDED;
  }
  const int64_t bits = bit_depth.value();
  if (bits == 8) {
    return BitDepth::B8;
  }
  if (bits == 16) {
    return BitDepth::B16;
  }
  if (bits == 24) {
    return BitDepth::B24;
  }
  if (bits == 32) {
    return BitDepth::B32;
  }
  if (bits == 64) {
    return BitDepth::B64;
  }
  std::ostringstream message;
  message << "Internal Error: unexpected bit depth value: " << bits;
  throw std::runtime_error(message.str());
}
/// Return the short textual name for a libsox encoding constant.
///
/// Encodings without an entry below are reported as "UNKNOWN"; this function
/// never throws on an unrecognised value.
std::string get_encoding(sox_encoding_t encoding) {
  // Start from the fallback and overwrite it for every recognised encoding.
  const char* label = "UNKNOWN";
  switch (encoding) {
    case SOX_ENCODING_SIGN2:
      label = "PCM_S";
      break;
    case SOX_ENCODING_UNSIGNED:
      label = "PCM_U";
      break;
    case SOX_ENCODING_FLOAT:
      label = "PCM_F";
      break;
    case SOX_ENCODING_FLAC:
      label = "FLAC";
      break;
    case SOX_ENCODING_ULAW:
      label = "ULAW";
      break;
    case SOX_ENCODING_ALAW:
      label = "ALAW";
      break;
    case SOX_ENCODING_MP3:
      label = "MP3";
      break;
    case SOX_ENCODING_VORBIS:
      label = "VORBIS";
      break;
    case SOX_ENCODING_AMR_WB:
      label = "AMR_WB";
      break;
    case SOX_ENCODING_AMR_NB:
      label = "AMR_NB";
      break;
    case SOX_ENCODING_OPUS:
      label = "OPUS";
      break;
    case SOX_ENCODING_GSM:
      label = "GSM";
      break;
    case SOX_ENCODING_UNKNOWN:
    default:
      break;
  }
  return std::string(label);
}
} // namespace sox_utils
} // namespace torchaudio
#ifndef TORCHAUDIO_SOX_TYPES_H
#define TORCHAUDIO_SOX_TYPES_H
#include <sox.h>
#include <torch/script.h>
namespace torchaudio {
namespace sox_utils {
/// File formats recognised by get_format_from_string().
enum class Format {
WAV,
MP3,
FLAC,
VORBIS,
AMR_NB,
AMR_WB,
AMB,
SPHERE,
GSM,
HTK,
};
/// Convert a file-type string (e.g. "wav", "mp3") into a Format value.
Format get_format_from_string(const std::string& format);
/// Sample encodings. NOT_PROVIDED marks "the caller did not pass an
/// encoding option", as opposed to UNKNOWN.
enum class Encoding {
NOT_PROVIDED,
UNKNOWN,
PCM_SIGNED,
PCM_UNSIGNED,
PCM_FLOAT,
FLAC,
ULAW,
ALAW,
MP3,
VORBIS,
AMR_WB,
AMR_NB,
OPUS,
};
/// Short textual name of an Encoding value.
std::string to_string(Encoding v);
/// Parse the optional `encoding` argument; a missing value maps to
/// Encoding::NOT_PROVIDED.
Encoding get_encoding_from_option(const c10::optional<std::string> encoding);
/// Bits per sample. Each enumerator's numeric value equals the bit count,
/// so a BitDepth can be cast straight to `unsigned`; NOT_PROVIDED is 0.
enum class BitDepth : unsigned {
NOT_PROVIDED = 0,
B8 = 8,
B16 = 16,
B24 = 24,
B32 = 32,
B64 = 64,
};
/// Parse the optional `bits_per_sample` argument; a missing value maps to
/// BitDepth::NOT_PROVIDED.
BitDepth get_bit_depth_from_option(const c10::optional<int64_t> bit_depth);
/// Short textual name of a libsox encoding constant.
std::string get_encoding(sox_encoding_t encoding);
} // namespace sox_utils
} // namespace torchaudio
#endif
#include <c10/core/ScalarType.h>
#include <sox.h>
#include <torchaudio/csrc/sox/types.h>
#include <torchaudio/csrc/sox/utils.h>
namespace torchaudio {
namespace sox_utils {
void set_seed(const int64_t seed) {
sox_get_globals()->ranqd1 = static_cast<sox_int32_t>(seed);
}
void set_verbosity(const int64_t verbosity) {
sox_get_globals()->verbosity = static_cast<unsigned>(verbosity);
}
void set_use_threads(const bool use_threads) {
sox_get_globals()->use_threads = static_cast<sox_bool>(use_threads);
}
void set_buffer_size(const int64_t buffer_size) {
sox_get_globals()->bufsiz = static_cast<size_t>(buffer_size);
}
int64_t get_buffer_size() {
return sox_get_globals()->bufsiz;
}
std::vector<std::vector<std::string>> list_effects() {
std::vector<std::vector<std::string>> effects;
for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) {
const sox_effect_handler_t* handler = (*fns)();
if (handler && handler->name) {
if (UNSUPPORTED_EFFECTS.find(handler->name) ==
UNSUPPORTED_EFFECTS.end()) {
effects.emplace_back(std::vector<std::string>{
handler->name,
handler->usage ? std::string(handler->usage) : std::string("")});
}
}
}
return effects;
}
std::vector<std::string> list_write_formats() {
std::vector<std::string> formats;
for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
const sox_format_handler_t* handler = fns->fn();
for (const char* const* names = handler->names; *names; ++names) {
if (!strchr(*names, '/') && handler->write)
formats.emplace_back(*names);
}
}
return formats;
}
std::vector<std::string> list_read_formats() {
std::vector<std::string> formats;
for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
const sox_format_handler_t* handler = fns->fn();
for (const char* const* names = handler->names; *names; ++names) {
if (!strchr(*names, '/') && handler->read)
formats.emplace_back(*names);
}
}
return formats;
}
SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}
SoxFormat::~SoxFormat() {
close();
}
sox_format_t* SoxFormat::operator->() const noexcept {
return fd_;
}
SoxFormat::operator sox_format_t*() const noexcept {
return fd_;
}
void SoxFormat::close() {
if (fd_ != nullptr) {
sox_close(fd_);
fd_ = nullptr;
}
}
void validate_input_file(const SoxFormat& sf, const std::string& path) {
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error(
"Error loading audio file: failed to open file " + path);
}
if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
throw std::runtime_error("Error loading audio file: unknown encoding.");
}
}
/// Same checks as validate_input_file(), with a placeholder in place of the
/// file path for the error message.
void validate_input_memfile(const SoxFormat& sf) {
  const std::string placeholder_path = "<in memory buffer>";
  validate_input_file(sf, placeholder_path);
}
/// Check that a Tensor can be handed to sox: CPU, 2D, and of dtype
/// uint8, int16, int32 or float32.
///
/// @throws std::runtime_error describing the first violated requirement.
void validate_input_tensor(const torch::Tensor tensor) {
  const bool on_cpu = tensor.device().is_cpu();
  if (!on_cpu) {
    throw std::runtime_error("Input tensor has to be on CPU.");
  }

  if (tensor.ndimension() != 2) {
    throw std::runtime_error("Input tensor has to be 2D.");
  }

  const c10::ScalarType scalar_type = tensor.dtype().toScalarType();
  const bool supported = scalar_type == c10::ScalarType::Byte ||
      scalar_type == c10::ScalarType::Short ||
      scalar_type == c10::ScalarType::Int ||
      scalar_type == c10::ScalarType::Float;
  if (!supported) {
    throw std::runtime_error(
        "Input tensor has to be one of float32, int32, int16 or uint8 type.");
  }
}
/// Pick the Tensor dtype used to hold samples of the given sox encoding.
///
/// - unsigned PCM            -> uint8
/// - signed PCM, 16 bit      -> int16
/// - signed PCM, 24 / 32 bit -> int32 (24-bit samples are widened)
/// - everything else         -> float32 (e.g. floating-point WAV, MP3, FLAC,
///   VORBIS)
///
/// @throws std::runtime_error for signed PCM with any other precision.
caffe2::TypeMeta get_dtype(
    const sox_encoding_t encoding,
    const unsigned precision) {
  // float32 unless one of the integer PCM cases below applies.
  c10::ScalarType scalar_type = torch::kFloat32;
  if (encoding == SOX_ENCODING_UNSIGNED) {
    scalar_type = torch::kUInt8;
  } else if (encoding == SOX_ENCODING_SIGN2) {
    if (precision == 16) {
      scalar_type = torch::kInt16;
    } else if (precision == 24 || precision == 32) {
      scalar_type = torch::kInt32;
    } else {
      throw std::runtime_error(
          "Only 16, 24, and 32 bits are supported for signed PCM.");
    }
  }
  return c10::scalarTypeToTypeMeta(scalar_type);
}
/// Convert a buffer of sox samples into a contiguous 2D Tensor.
///
/// @param buffer Interleaved sox samples.
/// @param num_samples Number of samples available in `buffer`.
/// @param num_channels Number of channels; must be positive.
/// @param dtype Target dtype (float32, int32, int16 or uint8).
/// @param normalize When true the output is float32 regardless of `dtype`.
/// @param channels_first When true the result is [channel, frame], otherwise
/// [frame, channel].
/// @throws std::runtime_error when `num_channels` is not positive or `dtype`
/// is not supported.
///
/// Only whole frames are converted: if `num_samples` is not a multiple of
/// `num_channels`, the trailing partial frame is ignored.
torch::Tensor convert_to_tensor(
    sox_sample_t* buffer,
    const int32_t num_samples,
    const int32_t num_channels,
    const caffe2::TypeMeta dtype,
    const bool normalize,
    const bool channels_first) {
  // Guard the division below; a zero channel count would divide by zero.
  if (num_channels <= 0) {
    throw std::runtime_error("Number of channels must be positive.");
  }
  const int32_t num_frames = num_samples / num_channels;
  // Number of elements the output tensor actually holds. The conversion
  // loops must stop here rather than at `num_samples`, otherwise a partial
  // trailing frame would be written past the end of the tensor storage.
  const int32_t num_elements = num_frames * num_channels;

  torch::Tensor t;
  uint64_t dummy = 0; // clip counter required by the SOX_SAMPLE_TO_* macros
  SOX_SAMPLE_LOCALS;
  if (normalize || dtype == torch::kFloat32) {
    t = torch::empty({num_frames, num_channels}, torch::kFloat32);
    // `float`, not `float_t`: the latter is only guaranteed to be at least
    // as wide as float, while a kFloat32 tensor stores exactly `float`.
    auto ptr = t.data_ptr<float>();
    for (int32_t i = 0; i < num_elements; ++i) {
      ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], dummy);
    }
  } else if (dtype == torch::kInt32) {
    // sox samples are already 32-bit integers; clone so the result does not
    // alias the caller's buffer.
    t = torch::from_blob(buffer, {num_frames, num_channels}, torch::kInt32)
            .clone();
  } else if (dtype == torch::kInt16) {
    t = torch::empty({num_frames, num_channels}, torch::kInt16);
    auto ptr = t.data_ptr<int16_t>();
    for (int32_t i = 0; i < num_elements; ++i) {
      ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], dummy);
    }
  } else if (dtype == torch::kUInt8) {
    t = torch::empty({num_frames, num_channels}, torch::kUInt8);
    auto ptr = t.data_ptr<uint8_t>();
    for (int32_t i = 0; i < num_elements; ++i) {
      ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy);
    }
  } else {
    throw std::runtime_error("Unsupported dtype.");
  }
  if (channels_first) {
    t = t.transpose(1, 0);
  }
  return t.contiguous();
}
/// Extract the lower-cased file extension (without the dot) from a path.
///
/// Only the final path component is inspected, so a dot inside a directory
/// name (e.g. "dir.v2/audio") is not mistaken for an extension separator.
/// @return The text after the last '.' of the final component, lower-cased;
/// an empty string when that component contains no '.'.
const std::string get_filetype(const std::string path) {
  const std::string::size_type sep = path.find_last_of('/');
  const std::string::size_type name_begin =
      (sep == std::string::npos) ? 0 : sep + 1;
  const std::string::size_type dot = path.find_last_of('.');
  if (dot == std::string::npos || dot < name_begin) {
    return std::string();
  }
  std::string ext = path.substr(dot + 1);
  // tolower() is only defined for values representable as unsigned char (or
  // EOF), so go through unsigned char to stay defined for non-ASCII bytes.
  std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
    return static_cast<char>(::tolower(c));
  });
  return ext;
}
namespace {
/// Resolve the (sox encoding, bits per sample) pair used when saving in a
/// WAV-style container (get_save_encoding calls this for WAV and AMB).
///
/// @param format File-type string, only used to build error messages.
/// @param dtype Tensor dtype; consulted only when neither `encoding` nor
/// `bits_per_sample` was provided.
/// @param encoding Requested encoding, or Encoding::NOT_PROVIDED.
/// @param bits_per_sample Requested bit depth, or BitDepth::NOT_PROVIDED.
/// @throws std::runtime_error for combinations rejected below.
std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
const std::string format,
caffe2::TypeMeta dtype,
const Encoding& encoding,
const BitDepth& bits_per_sample) {
switch (encoding) {
// No encoding requested: derive it from the bit depth, or from the dtype
// when the bit depth is missing as well.
case Encoding::NOT_PROVIDED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
switch (dtype.toScalarType()) {
case c10::ScalarType::Float:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
case c10::ScalarType::Int:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
case c10::ScalarType::Short:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
case c10::ScalarType::Byte:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error("Internal Error: Unexpected dtype.");
}
// 8-bit is stored as unsigned PCM; every other depth as signed PCM.
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
}
// Signed PCM: defaults to 32-bit, rejects 8-bit.
case Encoding::PCM_SIGNED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
case BitDepth::B8:
throw std::runtime_error(
format + " does not support 8-bit signed PCM encoding.");
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
}
// Unsigned PCM: 8-bit only.
case Encoding::PCM_UNSIGNED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for unsigned PCM encoding.");
}
// Floating-point PCM: 32-bit (default) or 64-bit.
case Encoding::PCM_FLOAT:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B32:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
case BitDepth::B64:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 64);
default:
throw std::runtime_error(
format +
" only supports 32-bit or 64-bit for floating-point PCM encoding.");
}
// mu-law: 8-bit only.
case Encoding::ULAW:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for mu-law encoding.");
}
// a-law: 8-bit only.
case Encoding::ALAW:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for a-law encoding.");
}
// Any other Encoding value is rejected.
default:
throw std::runtime_error(
format + " does not support encoding: " + to_string(encoding));
}
}
/// Resolve the (sox encoding, bits per sample) pair for saving `format`.
///
/// The three user-facing options are first parsed into Format / Encoding /
/// BitDepth; each of those parsers throws on an unrecognised value. The
/// switch below then applies the per-format rules.
///
/// @param format File-type string (e.g. "wav", "mp3").
/// @param dtype Tensor dtype; only forwarded to the WAV/AMB helper.
/// @param encoding Optional encoding name.
/// @param bits_per_sample Optional bit depth.
/// @throws std::runtime_error for unsupported formats or option combinations.
std::tuple<sox_encoding_t, unsigned> get_save_encoding(
const std::string& format,
const caffe2::TypeMeta dtype,
const c10::optional<std::string> encoding,
const c10::optional<int64_t> bits_per_sample) {
const Format fmt = get_format_from_string(format);
const Encoding enc = get_encoding_from_option(encoding);
const BitDepth bps = get_bit_depth_from_option(bits_per_sample);
switch (fmt) {
// WAV and AMB share one rule table.
case Format::WAV:
case Format::AMB:
return get_save_encoding_for_wav(format, dtype, enc, bps);
// MP3, HTK, VORBIS, AMR_NB (and GSM further down) accept neither option and
// always report 16 bits.
case Format::MP3:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("mp3 does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"mp3 does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_MP3, 16);
case Format::HTK:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("htk does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"htk does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
case Format::VORBIS:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("vorbis does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"vorbis does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_VORBIS, 16);
case Format::AMR_NB:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("amr-nb does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"amr-nb does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
// FLAC: no `encoding` option; bit depth up to 24. A missing bit depth is
// passed on as 0 (the numeric value of BitDepth::NOT_PROVIDED).
case Format::FLAC:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("flac does not support `encoding` option.");
switch (bps) {
case BitDepth::B32:
case BitDepth::B64:
throw std::runtime_error(
"flac does not support `bits_per_sample` larger than 24.");
default:
return std::make_tuple<>(
SOX_ENCODING_FLAC, static_cast<unsigned>(bps));
}
// SPHERE: signed PCM (default 32-bit), mu-law or a-law.
case Format::SPHERE:
switch (enc) {
case Encoding::NOT_PROVIDED:
case Encoding::PCM_SIGNED:
switch (bps) {
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bps));
}
case Encoding::PCM_UNSIGNED:
throw std::runtime_error(
"sph does not support unsigned integer PCM.");
case Encoding::PCM_FLOAT:
throw std::runtime_error("sph does not support floating point PCM.");
case Encoding::ULAW:
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
throw std::runtime_error(
"sph only supports 8-bit for mu-law encoding.");
}
// NOTE(review): unlike mu-law above, a-law forwards a non-8-bit depth
// instead of throwing; this asymmetry looks deliberate but is worth
// confirming.
case Encoding::ALAW:
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
return std::make_tuple<>(
SOX_ENCODING_ALAW, static_cast<unsigned>(bps));
}
default:
throw std::runtime_error(
"sph does not support encoding: " + encoding.value());
}
case Format::GSM:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("gsm does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"gsm does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_GSM, 16);
// Formats without a case above (e.g. Format::AMR_WB) end up here.
default:
throw std::runtime_error("Unsupported format: " + format);
}
}
unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
if (filetype == "mp3")
return SOX_UNSPEC;
if (filetype == "flac")
return 24;
if (filetype == "ogg" || filetype == "vorbis")
return SOX_UNSPEC;
if (filetype == "wav" || filetype == "amb") {
switch (dtype.toScalarType()) {
case c10::ScalarType::Byte:
return 8;
case c10::ScalarType::Short:
return 16;
case c10::ScalarType::Int:
return 32;
case c10::ScalarType::Float:
return 32;
default:
throw std::runtime_error("Unsupported dtype.");
}
}
if (filetype == "sph")
return 32;
if (filetype == "amr-nb") {
return 16;
}
if (filetype == "gsm") {
return 16;
}
if (filetype == "htk") {
return 16;
}
throw std::runtime_error("Unsupported file type: " + filetype);
}
} // namespace
/// Describe a Tensor as a sox signal.
///
/// @param waveform 2D Tensor holding the audio.
/// @param sample_rate Sample rate to report.
/// @param filetype Target file type; determines the reported precision.
/// @param channels_first Selects which dimension holds the channels
/// (0 when true, 1 otherwise).
/// @return Signal info whose length is the total element count of the Tensor.
sox_signalinfo_t get_signalinfo(
    const torch::Tensor* waveform,
    const int64_t sample_rate,
    const std::string filetype,
    const bool channels_first) {
  const sox_rate_t rate = static_cast<sox_rate_t>(sample_rate);
  const int channel_dim = channels_first ? 0 : 1;
  const unsigned channels = static_cast<unsigned>(waveform->size(channel_dim));
  const unsigned precision = get_precision(filetype, waveform->dtype());
  const uint64_t length = static_cast<uint64_t>(waveform->numel());
  return sox_signalinfo_t{rate, channels, precision, length};
}
/// Describe how samples of a Tensor with the given dtype are encoded.
///
/// uint8 -> unsigned 8-bit, int16 -> signed 16-bit, int32 -> signed 32-bit,
/// float32 -> float 32-bit. The remaining fields are filled with sox's
/// "default / unset" values.
/// @throws std::runtime_error for any other dtype.
sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
  sox_encoding_t sample_encoding;
  unsigned sample_bits;
  switch (dtype.toScalarType()) {
    case c10::ScalarType::Byte:
      sample_encoding = SOX_ENCODING_UNSIGNED;
      sample_bits = 8;
      break;
    case c10::ScalarType::Short:
      sample_encoding = SOX_ENCODING_SIGN2;
      sample_bits = 16;
      break;
    case c10::ScalarType::Int:
      sample_encoding = SOX_ENCODING_SIGN2;
      sample_bits = 32;
      break;
    case c10::ScalarType::Float:
      sample_encoding = SOX_ENCODING_FLOAT;
      sample_bits = 32;
      break;
    default:
      throw std::runtime_error("Unsupported dtype.");
  }
  return sox_encodinginfo_t{
      /*encoding=*/sample_encoding,
      /*bits_per_sample=*/sample_bits,
      /*compression=*/HUGE_VAL,
      /*reverse_bytes=*/sox_option_default,
      /*reverse_nibbles=*/sox_option_default,
      /*reverse_bits=*/sox_option_default,
      /*opposite_endian=*/sox_false};
}
/// Build the sox encoding description used when writing an output file.
///
/// The encoding and bit depth come from get_save_encoding(); `compression`
/// falls back to HUGE_VAL when it is not provided.
sox_encodinginfo_t get_encodinginfo_for_save(
    const std::string& format,
    const caffe2::TypeMeta dtype,
    const c10::optional<double> compression,
    const c10::optional<std::string> encoding,
    const c10::optional<int64_t> bits_per_sample) {
  const std::tuple<sox_encoding_t, unsigned> resolved =
      get_save_encoding(format, dtype, encoding, bits_per_sample);
  const sox_encoding_t save_encoding = std::get<0>(resolved);
  const unsigned save_bits = std::get<1>(resolved);
  const double save_compression =
      compression.has_value() ? compression.value() : HUGE_VAL;
  return sox_encodinginfo_t{
      /*encoding=*/save_encoding,
      /*bits_per_sample=*/save_bits,
      /*compression=*/save_compression,
      /*reverse_bytes=*/sox_option_default,
      /*reverse_nibbles=*/sox_option_default,
      /*reverse_bits=*/sox_option_default,
      /*opposite_endian=*/sox_false};
}
// Register the sox_utils helpers as operators in the `torchaudio` library
// fragment. The string is the operator name; the pointer is the C++
// implementation bound to it.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
// Setters for libsox global options.
m.def("torchaudio::sox_utils_set_seed", &torchaudio::sox_utils::set_seed);
m.def(
"torchaudio::sox_utils_set_verbosity",
&torchaudio::sox_utils::set_verbosity);
m.def(
"torchaudio::sox_utils_set_use_threads",
&torchaudio::sox_utils::set_use_threads);
m.def(
"torchaudio::sox_utils_set_buffer_size",
&torchaudio::sox_utils::set_buffer_size);
// Capability queries.
m.def(
"torchaudio::sox_utils_list_effects",
&torchaudio::sox_utils::list_effects);
m.def(
"torchaudio::sox_utils_list_read_formats",
&torchaudio::sox_utils::list_read_formats);
m.def(
"torchaudio::sox_utils_list_write_formats",
&torchaudio::sox_utils::list_write_formats);
// Getter matching sox_utils_set_buffer_size.
m.def(
"torchaudio::sox_utils_get_buffer_size",
&torchaudio::sox_utils::get_buffer_size);
}
} // namespace sox_utils
} // namespace torchaudio
#ifndef TORCHAUDIO_SOX_UTILS_H
#define TORCHAUDIO_SOX_UTILS_H
#include <sox.h>
#include <torch/script.h>
namespace torchaudio {
namespace sox_utils {
////////////////////////////////////////////////////////////////////////////////
// APIs for Python interaction
////////////////////////////////////////////////////////////////////////////////
/// Set sox global options
void set_seed(const int64_t seed);
void set_verbosity(const int64_t verbosity);
void set_use_threads(const bool use_threads);
void set_buffer_size(const int64_t buffer_size);
/// Get the buffer size currently stored in the sox global options
int64_t get_buffer_size();
/// List available effects as {name, usage} pairs, excluding the entries of
/// UNSUPPORTED_EFFECTS
std::vector<std::vector<std::string>> list_effects();
/// List the names of formats whose handler can read
std::vector<std::string> list_read_formats();
/// List the names of formats whose handler can write
std::vector<std::string> list_write_formats();
////////////////////////////////////////////////////////////////////////////////
// Utilities for sox_io / sox_effects implementations
////////////////////////////////////////////////////////////////////////////////
/// Effect names that list_effects() filters out of its result.
const std::unordered_set<std::string> UNSUPPORTED_EFFECTS =
{"input", "output", "spectrogram", "noiseprof", "noisered", "splice"};
/// Helper class to automatically close sox_format_t*.
/// Non-copyable and non-movable: exactly one object owns the handle.
struct SoxFormat {
explicit SoxFormat(sox_format_t* fd) noexcept;
SoxFormat(const SoxFormat& other) = delete;
SoxFormat(SoxFormat&& other) = delete;
SoxFormat& operator=(const SoxFormat& other) = delete;
SoxFormat& operator=(SoxFormat&& other) = delete;
~SoxFormat();
// Access to the wrapped handle; both return nullptr once it is closed.
sox_format_t* operator->() const noexcept;
operator sox_format_t*() const noexcept;
// Close the handle now; calling it again is a no-op.
void close();
private:
sox_format_t* fd_;
};
///
/// Verify that the input file was opened and has a known encoding.
/// `path` is only used in the error message.
void validate_input_file(const SoxFormat& sf, const std::string& path);
/// Verify that the input memory buffer was opened and has a known encoding
void validate_input_memfile(const SoxFormat& sf);
///
/// Verify that input Tensor is 2D, CPU and either uint8, int16, int32 or
/// float32
void validate_input_tensor(const torch::Tensor);
///
/// Get target dtype for the given encoding and precision.
caffe2::TypeMeta get_dtype(
const sox_encoding_t encoding,
const unsigned precision);
///
/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
/// NOTE: This function might modify the values in the input buffer to
/// reduce the number of memory copy.
/// @param buffer Pointer to buffer that contains audio data.
/// @param num_samples The number of samples to read.
/// @param num_channels The number of channels. Used to reshape the resulting
/// Tensor.
/// @param dtype Target dtype. Determines the output dtype and value range in
/// conjunction with normalization.
/// @param normalize Perform normalization. Only effective when dtype is not
/// kFloat32. When effective, the output tensor is kFloat32 type and value range
/// is [-1.0, 1.0]
/// @param channels_first When True, output Tensor has shape of [num_channels,
/// num_frames].
torch::Tensor convert_to_tensor(
sox_sample_t* buffer,
const int32_t num_samples,
const int32_t num_channels,
const caffe2::TypeMeta dtype,
const bool normalize,
const bool channels_first);
/// Extract the lower-cased extension from a file path
const std::string get_filetype(const std::string path);
/// Get sox_signalinfo_t for passing a torch::Tensor object.
sox_signalinfo_t get_signalinfo(
const torch::Tensor* waveform,
const int64_t sample_rate,
const std::string filetype,
const bool channels_first);
/// Get sox_encodinginfo_t for Tensor I/O
sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype);
/// Get sox_encodinginfo_t for saving to file/file object
sox_encodinginfo_t get_encodinginfo_for_save(
const std::string& format,
const caffe2::TypeMeta dtype,
const c10::optional<double> compression,
const c10::optional<std::string> encoding,
const c10::optional<int64_t> bits_per_sample);
} // namespace sox_utils
} // namespace torchaudio
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment