Unverified commit 0c263a93, authored by Caroline Chen, committed by GitHub

Replace existing prototype RNNT Loss (#1479)

Replace the prototype RNNT implementation (using warp-transducer) with one without external library dependencies
parent b5d80279
#pragma once
#include <torchaudio/csrc/rnnt/cpu/cpu_kernels.h>
#include <torchaudio/csrc/rnnt/workspace.h>
namespace torchaudio {
namespace rnnt {
namespace cpu {
// Inputs:
// workspace: pre-allocated workspace holding intermediate buffers
// (denominators, log prob pairs, alphas, betas).
// logits: pointer to (B, maxT, maxU, D) logits.
// targets: pointer to (B, maxU - 1) targets in the batch.
// srcLengths: pointer to (B, ) source lengths in the batch.
// tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
// costs: pointer to (B, ) costs in the batch.
// gradients: pointer to (B, maxT, maxU, D) gradients in the batch.
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* gradients = nullptr) {
const Options& options = workspace.GetOptions();
CHECK_EQ(options.device_, CPU);
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
const int& D = options.numTargets_;
{ // compute denominators.
LogSumExp2D<DTYPE, CAST_DTYPE>(
/*N=*/B * maxT * maxU,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
}
{ // compute log prob pairs.
ComputeLogProbs<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs());
}
{ // compute alphas and betas.
ComputeAlphasBetas<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alphas=*/workspace.GetPointerToAlphas(),
/*betas=*/workspace.GetPointerToBetas(),
/*costs=*/costs);
}
if (gradients != nullptr) {
ComputeGradients<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*alphas=*/workspace.GetPointerToAlphas(),
/*betas=*/workspace.GetPointerToBetas(),
/*gradients=*/gradients);
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* alphas) {
const Options& options = workspace.GetOptions();
CHECK_EQ(options.device_, CPU);
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
const int& D = options.numTargets_;
{ // compute denominators.
LogSumExp2D<DTYPE, CAST_DTYPE>(
/*N=*/B * maxT * maxU,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
}
{ // compute log prob pairs.
ComputeLogProbs<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs());
}
{ // compute alphas.
ComputeAlphas<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alphas=*/alphas);
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* betas) {
const Options& options = workspace.GetOptions();
CHECK_EQ(options.device_, CPU);
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
const int& D = options.numTargets_;
{ // compute denominators.
LogSumExp2D<DTYPE, CAST_DTYPE>(
/*N=*/B * maxT * maxU,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
}
{ // compute log prob pairs.
ComputeLogProbs<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs());
}
{ // compute betas.
ComputeBetas<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/costs,
/*betas=*/betas);
}
return SUCCESS;
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
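For orientation, a hedged Python sketch of the buffers this CPU entry point works over (sizes illustrative; the C++ code receives these as flat pointers, not tensors):

import torch

B, maxT, maxU, D = 2, 10, 5, 29                    # illustrative sizes
logits = torch.rand(B, maxT, maxU, D)              # joiner output, (B, maxT, maxU, D)
targets = torch.randint(1, D, (B, maxU - 1), dtype=torch.int32)  # padded targets
src_lengths = torch.full((B,), maxT, dtype=torch.int32)          # per-example source lengths
tgt_lengths = torch.full((B,), maxU - 1, dtype=torch.int32)      # per-example target lengths
costs = torch.empty(B)                             # one cost per example
gradients = torch.empty_like(logits)               # same layout as logits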
#pragma once
#include <cassert>
#include <torchaudio/csrc/rnnt/cpu/math.h>
namespace torchaudio {
namespace rnnt {
inline HOST_AND_DEVICE bool in_range(
int start,
int end, // inclusive
int val) {
return start <= val && val <= end;
}
#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1
struct Indexer2D {
const int& size2_;
FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {}
FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) {
return index1 * size2_ + index2;
}
};
struct Indexer3D {
const int& size2_;
const int& size3_;
FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3)
: size2_(size2), size3_(size3) {}
FORCE_INLINE HOST_AND_DEVICE int operator()(
int index1,
int index2,
int index3) {
return (index1 * size2_ + index2) * size3_ + index3;
}
};
struct Indexer4D {
const int& size2_;
const int& size3_;
const int& size4_;
HOST_AND_DEVICE Indexer4D(
const int& size2,
const int& size3,
const int& size4)
: size2_(size2), size3_(size3), size4_(size4) {}
HOST_AND_DEVICE int operator()(
int index1,
int index2,
int index3,
int index4) {
return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4;
}
};
} // namespace rnnt
} // namespace torchaudio
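The indexer structs above are plain row-major flattening. A small illustrative Python check of the 3-D formula against sequential enumeration (the helper name is ours, not part of the C++ API):

def index3d(size2, size3, i1, i2, i3):
    # mirrors Indexer3D::operator(): (i1 * size2 + i2) * size3 + i3
    return (i1 * size2 + i2) * size3 + i3

flat = 0
for i1 in range(2):          # size1 never enters the formula
    for i2 in range(3):      # size2
        for i3 in range(4):  # size3
            assert index3d(3, 4, i1, i2, i3) == flat
            flat += 1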
#pragma once
#include <torchaudio/csrc/rnnt/macros.h>
namespace torchaudio {
namespace rnnt {
namespace math {
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
if (x > y)
return x;
else
return y;
}
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
if (x > y)
return y;
else
return x;
}
// log_sum_exp
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);
template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
if (y > x) {
return y + log1pf(expf(x - y));
} else {
return x + log1pf(expf(y - x));
}
}
} // namespace math
} // namespace rnnt
} // namespace torchaudio
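The lse specialization evaluates log(exp(x) + exp(y)) in the standard numerically stable form, pivoting on the larger argument so the exponential never overflows. A quick illustrative check in Python:

import math

def lse(x, y):
    # stable two-operand log-sum-exp, mirroring the float specialization
    if y > x:
        x, y = y, x
    return x + math.log1p(math.exp(y - x))

# agrees with the naive form where that form does not overflow
assert abs(lse(1.0, 2.0) - math.log(math.exp(1.0) + math.exp(2.0))) < 1e-12
# stays finite where the naive form would overflow to inf
assert abs(lse(1000.0, 1000.0) - (1000.0 + math.log(2.0))) < 1e-12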
#include <torchaudio/csrc/rnnt/macros.h>
const char* ToString(level_t level) {
switch (level) {
case INFO:
return "INFO";
case WARNING:
return "WARNING";
case ERROR:
return "ERROR";
case FATAL:
return "FATAL";
default:
return "UNKNOWN";
}
}
#pragma once
#ifdef USE_CUDA
#define WARP_SIZE 32
#define MAX_THREADS_PER_BLOCK 1024
#define REDUCE_THREADS 256
#define HOST_AND_DEVICE __host__ __device__
#define FORCE_INLINE __forceinline__
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#else
#define HOST_AND_DEVICE
#define FORCE_INLINE inline
#endif // USE_CUDA
#include <cstring>
#include <iostream>
typedef enum { INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 } level_t;
const char* ToString(level_t level);
#pragma once
#include <iostream>
#ifdef USE_CUDA
#include <cuda_runtime.h>
#endif // USE_CUDA
#include <torchaudio/csrc/rnnt/macros.h>
#include <torchaudio/csrc/rnnt/types.h>
namespace torchaudio {
namespace rnnt {
typedef struct Options {
// the device to compute transducer loss.
device_t device_;
#ifdef USE_CUDA
// the stream to launch kernels in when using GPU.
cudaStream_t stream_;
#endif
// The maximum number of threads that can be used.
int numThreads_;
// the index for "blank".
int blank_;
// whether to backtrack the best path.
bool backtrack_;
// gradient clamp value.
float clamp_;
// batch size = B.
int batchSize_;
// Number of hypos per sample = H
int nHypos_;
// the maximum length of src encodings = max_T.
int maxSrcLen_;
// the maximum length of tgt encodings = max_U.
int maxTgtLen_;
// num_targets = D.
int numTargets_;
// If set to true, inputs are logits and the gradients are
// fused with the log-softmax gradients.
// If set to false, log_softmax is computed outside of the loss.
// True by default.
bool fusedLogSmax_;
Options()
: device_(UNDEFINED),
numThreads_(0),
blank_(-1),
backtrack_(false),
clamp_(-1), // negative for disabling clamping by default.
batchSize_(0),
nHypos_(1),
maxSrcLen_(0),
maxTgtLen_(0),
numTargets_(0),
fusedLogSmax_(true) {}
int BU() const {
return batchSize_ * maxTgtLen_ * nHypos_;
}
int BTU() const {
return batchSize_ * maxSrcLen_ * maxTgtLen_ * nHypos_;
}
friend std::ostream& operator<<(std::ostream& os, const Options& options) {
os << "Options("
<< "batchSize_=" << options.batchSize_ << ", "
<< "maxSrcLen_=" << options.maxSrcLen_ << ", "
<< "maxTgtLen_=" << options.maxTgtLen_ << ", "
<< "numTargets_=" << options.numTargets_ << ")";
return os;
}
} Options;
} // namespace rnnt
} // namespace torchaudio
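BU() and BTU() are the element counts that workspace sizes are derived from. A worked example with illustrative values:

# B=2 examples, H=1 hypothesis each, maxT=10 source steps, maxU=5 target steps
B, H, maxT, maxU = 2, 1, 10, 5
BU = B * maxU * H            # 10;  sizes the GPU alpha/beta counters
BTU = B * maxT * maxU * H    # 100; sizes denominators, alphas, and betas
assert (BU, BTU) == (10, 100)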
#pragma once
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* gradients = nullptr) {
switch (workspace.GetOptions().device_) {
case CPU: {
status_t status = cpu::Compute<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/costs,
/*gradients=*/gradients);
return status;
}
case GPU: {
status_t status = gpu::Compute<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/costs,
/*gradients=*/gradients);
return status;
}
default: {
return FAILURE;
}
};
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* alphas) {
switch (workspace.GetOptions().device_) {
case CPU: {
status_t status = cpu::ComputeAlphas<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alphas=*/alphas);
return status;
}
case GPU: {
status_t status = gpu::ComputeAlphas<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alphas=*/alphas);
return status;
}
default: {
return FAILURE;
}
};
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* betas) {
switch (workspace.GetOptions().device_) {
case CPU: {
status_t status = cpu::ComputeBetas<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/costs,
/*betas=*/betas);
return status;
}
case GPU: {
status_t status = gpu::ComputeBetas<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/costs,
/*betas=*/betas);
return status;
}
default: {
return FAILURE;
}
};
}
} // namespace rnnt
} // namespace torchaudio
#include <torchaudio/csrc/rnnt/types.h>
namespace torchaudio {
namespace rnnt {
const char* toString(status_t status) {
switch (status) {
case SUCCESS:
return "success";
case FAILURE:
return "failure";
case COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED:
return "compute_denominator_reduce_max_failed";
case COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED:
return "compute_denominator_reduce_sum_failed";
case COMPUTE_LOG_PROBS_FAILED:
return "compute_log_probs_failed";
case COMPUTE_ALPHAS_BETAS_COSTS_FAILED:
return "compute_alphas_betas_costs_failed";
case COMPUTE_GRADIENTS_FAILED:
return "compute_gradients_failed";
default:
return "unknown";
}
}
const char* toString(device_t device) {
switch (device) {
case UNDEFINED:
return "undefined";
case CPU:
return "cpu";
case GPU:
return "gpu";
default:
return "unknown";
}
}
} // namespace rnnt
} // namespace torchaudio
#pragma once
namespace torchaudio {
namespace rnnt {
typedef enum {
SUCCESS = 0,
FAILURE = 1,
COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED = 2,
COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED = 3,
COMPUTE_LOG_PROBS_FAILED = 4,
COMPUTE_ALPHAS_BETAS_COSTS_FAILED = 5,
COMPUTE_GRADIENTS_FAILED = 6
} status_t;
typedef enum { UNDEFINED = 0, CPU = 1, GPU = 2 } device_t;
const char* toString(status_t status);
const char* toString(device_t device);
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <cstring>
#include <vector>
#include <torchaudio/csrc/rnnt/options.h>
namespace torchaudio {
namespace rnnt {
// Since CUDA has strict memory alignment, it's better to keep allocated memory
// blocks separate for different data types.
// DtypeWorkspace holds a "view" of workspace for:
// 1. softmax denominators (in log form), size = B * max_T * max_U
// 2. log probability pairs for blank and target, size = B * max_T * max_U
// 3. alphas, size = B * max_T * max_U
// 4. betas, size = B * max_T * max_U
template <typename DTYPE>
class DtypeWorkspace {
public:
DtypeWorkspace() : options_(), size_(0), data_(nullptr) {}
DtypeWorkspace(const Options& options, DTYPE* data, int size)
: DtypeWorkspace() {
Reset(options, data, size);
}
~DtypeWorkspace() {}
static int ComputeSizeFromOptions(const Options& options) {
CHECK_NE(options.device_, UNDEFINED);
return ComputeSizeForDenominators(options) +
ComputeSizeForLogProbs(options) + ComputeSizeForAlphas(options) +
ComputeSizeForBetas(options);
}
void Free();
void Reset(const Options& options, DTYPE* data, int size) {
int needed_size = ComputeSizeFromOptions(options);
CHECK_LE(needed_size, size);
options_ = options;
data_ = data;
size_ = size;
}
int Size() const {
return size_;
}
DTYPE* GetPointerToDenominators() const {
return data_;
}
DTYPE* GetPointerToLogProbs() const {
return GetPointerToDenominators() + ComputeSizeForDenominators(options_);
}
DTYPE* GetPointerToAlphas() const {
return GetPointerToLogProbs() + ComputeSizeForLogProbs(options_);
}
DTYPE* GetPointerToBetas() const {
return GetPointerToAlphas() + ComputeSizeForAlphas(options_);
}
private:
static int ComputeSizeForDenominators(const Options& options) { // B * T * U
return options.BTU();
}
static int ComputeSizeForLogProbs(const Options& options) { // B * T * U * 2
return options.BTU() * 2;
}
static int ComputeSizeForAlphas(const Options& options) { // B * T * U
return options.BTU();
}
static int ComputeSizeForBetas(const Options& options) { // B * T * U
return options.BTU();
}
Options options_;
int size_; // number of elements in allocated memory.
DTYPE* data_; // pointer to the allocated memory.
};
// IntWorkspace holds a "view" of workspace for:
// 1. alpha counters, size = B * max_U
// 2. beta counters, size = B * max_U
class IntWorkspace {
public:
IntWorkspace() : options_(), size_(0), data_(nullptr) {}
IntWorkspace(const Options& options, int* data, int size) : IntWorkspace() {
Reset(options, data, size);
}
~IntWorkspace() {}
static int ComputeSizeFromOptions(const Options& options) {
return ComputeSizeForAlphaCounters(options) +
ComputeSizeForBetaCounters(options);
}
void Reset(const Options& options, int* data, int size) {
int needed_size = ComputeSizeFromOptions(options);
CHECK_LE(needed_size, size);
options_ = options;
data_ = data;
size_ = size;
ResetAlphaBetaCounters();
}
int Size() const {
return size_;
}
int* GetPointerToAlphaCounters() const {
CHECK_EQ(options_.device_, GPU);
return data_;
}
int* GetPointerToBetaCounters() const {
CHECK_EQ(options_.device_, GPU);
return GetPointerToAlphaCounters() + ComputeSizeForAlphaCounters(options_);
}
private:
inline void ResetAlphaBetaCounters() {
#ifdef USE_CUDA
if (data_ != nullptr && options_.device_ == GPU) {
cudaMemset(
GetPointerToAlphaCounters(),
0,
ComputeSizeForAlphaCounters(options_) * sizeof(int));
cudaMemset(
GetPointerToBetaCounters(),
0,
ComputeSizeForBetaCounters(options_) * sizeof(int));
}
#endif // USE_CUDA
}
static int ComputeSizeForAlphaCounters(const Options& options) { // B * U
#ifdef USE_CUDA
if (options.device_ == GPU) {
return options.BU();
} else {
return 0;
}
#else
return 0;
#endif // USE_CUDA
}
static int ComputeSizeForBetaCounters(const Options& options) { // B * U
#ifdef USE_CUDA
if (options.device_ == GPU) {
return options.BU();
} else {
return 0;
}
#else
return 0;
#endif // USE_CUDA
}
Options options_;
int size_; // number of elements in allocated memory.
int* data_; // pointer to the allocated memory.
};
// Workspace<DTYPE> holds:
// 1. DtypeWorkspace<DTYPE>
// 2. IntWorkspace
template <typename DTYPE>
class Workspace {
public:
Workspace() : options_(), dtype_workspace_(), int_workspace_() {}
Workspace(
const Options& options,
DTYPE* dtype_data,
int dtype_size,
int* int_data,
int int_size)
: Workspace() {
Reset(options, dtype_data, dtype_size, int_data, int_size);
}
~Workspace() {}
void Reset(
const Options& options,
DTYPE* dtype_data,
int dtype_size,
int* int_data,
int int_size) {
options_ = options;
dtype_workspace_.Reset(options_, dtype_data, dtype_size);
int_workspace_.Reset(options_, int_data, int_size);
}
const Options& GetOptions() const {
return options_;
}
DTYPE* GetPointerToDenominators() const {
return dtype_workspace_.GetPointerToDenominators();
}
DTYPE* GetPointerToLogProbs() const {
return dtype_workspace_.GetPointerToLogProbs();
}
DTYPE* GetPointerToAlphas() const {
return dtype_workspace_.GetPointerToAlphas();
}
DTYPE* GetPointerToBetas() const {
return dtype_workspace_.GetPointerToBetas();
}
int* GetPointerToAlphaCounters() const {
return int_workspace_.GetPointerToAlphaCounters();
}
int* GetPointerToBetaCounters() const {
return int_workspace_.GetPointerToBetaCounters();
}
private:
Options options_;
DtypeWorkspace<DTYPE> dtype_workspace_;
IntWorkspace int_workspace_;
};
} // namespace rnnt
} // namespace torchaudio
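DtypeWorkspace hands out four consecutive views over one allocation. A hedged Python sketch of the offset arithmetic implied by the GetPointerTo* chain (the helper name is ours):

def dtype_workspace_offsets(btu):
    # layout: [denominators | log_probs (2x) | alphas | betas]
    denominators = 0
    log_probs = denominators + btu    # after B*T*U denominators
    alphas = log_probs + 2 * btu      # log prob pairs take 2 * B*T*U
    betas = alphas + btu
    total = betas + btu               # minimum size that Reset() checks against
    return denominators, log_probs, alphas, betas, total

# with BTU = 100 the views start at 0, 100, 300, 400 and need 500 elements total
assert dtype_workspace_offsets(100) == (0, 100, 300, 400, 500)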
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
#include <torch/script.h>
#include "rnnt.h"
namespace {
int64_t cpu_rnnt_loss(
torch::Tensor acts,
torch::Tensor labels,
torch::Tensor input_lengths,
torch::Tensor label_lengths,
torch::Tensor costs,
torch::Tensor grads,
int64_t blank_label,
int64_t num_threads) {
TORCH_CHECK(labels.dtype() == torch::kInt32, "labels must be int32 type");
TORCH_CHECK(
label_lengths.dtype() == torch::kInt32,
"label_lengths must be int32 type");
TORCH_CHECK(
input_lengths.dtype() == torch::kInt32, "input_lengths must be int32 type");
TORCH_CHECK(acts.is_contiguous(), "acts must be contiguous");
TORCH_CHECK(labels.is_contiguous(), "labels must be contiguous");
TORCH_CHECK(
label_lengths.is_contiguous(), "label_lengths must be contiguous");
TORCH_CHECK(input_lengths.is_contiguous(), "input_lengths must be contiguous");
TORCH_CHECK(
input_lengths.size(0) == acts.size(0),
"batch dimension mismatch between acts and input_lengths: each example must have a length");
TORCH_CHECK(
label_lengths.size(0) == acts.size(0),
"batch dimension mismatch between acts and label_lengths: each example must have a label length");
TORCH_CHECK(acts.dim() == 4, "acts must be 4-D (batch, time, label, class)");
TORCH_CHECK(
labels.dim() == 2, "labels must be 2-D (batch, max label length)");
TORCH_CHECK(input_lengths.dim() == 1, "input_lengths must be 1-D");
TORCH_CHECK(label_lengths.dim() == 1, "label_lengths must be 1-D");
int maxT = acts.size(1);
int maxU = acts.size(2);
int minibatch_size = acts.size(0);
int alphabet_size = acts.size(3);
TORCH_CHECK(
at::max(input_lengths).item().toInt() == maxT, "input length mismatch");
TORCH_CHECK(
at::max(label_lengths).item().toInt() + 1 == maxU,
"output length mismatch");
rnntOptions options;
memset(&options, 0, sizeof(options));
options.maxT = maxT;
options.maxU = maxU;
options.blank_label = blank_label;
options.batch_first = true;
options.loc = RNNT_CPU;
options.num_threads = num_threads;
// have to use at least one thread
options.num_threads = std::max(options.num_threads, (unsigned int)1);
size_t cpu_size_bytes = 0;
switch (acts.scalar_type()) {
case torch::ScalarType::Float: {
get_workspace_size(maxT, maxU, minibatch_size, false, &cpu_size_bytes);
std::vector<float> cpu_workspace(cpu_size_bytes / sizeof(float), 0);
compute_rnnt_loss(
acts.data_ptr<float>(),
grads.data_ptr<float>(),
labels.data_ptr<int>(),
label_lengths.data_ptr<int>(),
input_lengths.data_ptr<int>(),
alphabet_size,
minibatch_size,
costs.data_ptr<float>(),
cpu_workspace.data(),
options);
return 0;
}
case torch::ScalarType::Double: {
get_workspace_size(
maxT, maxU, minibatch_size, false, &cpu_size_bytes, sizeof(double));
std::vector<double> cpu_workspace(cpu_size_bytes / sizeof(double), 0);
compute_rnnt_loss_fp64(
acts.data_ptr<double>(),
grads.data_ptr<double>(),
labels.data_ptr<int>(),
label_lengths.data_ptr<int>(),
input_lengths.data_ptr<int>(),
alphabet_size,
minibatch_size,
costs.data_ptr<double>(),
cpu_workspace.data(),
options);
return 0;
}
default:
TORCH_CHECK(
false,
std::string(__func__) + " not implemented for '" +
toString(acts.scalar_type()) + "'");
}
return -1;
}
} // namespace
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss", &cpu_rnnt_loss);
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss(Tensor acts,"
"Tensor labels,"
"Tensor input_lengths,"
"Tensor label_lengths,"
"Tensor costs,"
"Tensor grads,"
"int blank_label,"
"int num_threads) -> int");
}
import torch
__all__ = [
"RNNTLoss",
"rnnt_loss",
]
def _rnnt_loss_alphas(
logits,
targets,
logit_lengths,
target_lengths,
blank=-1,
clamp=-1,
):
"""
Compute alphas for RNN transducer loss.
See documentation for RNNTLoss
"""
targets = targets.to(device=logits.device)
logit_lengths = logit_lengths.to(device=logits.device)
target_lengths = target_lengths.to(device=logits.device)
# make sure all int tensors are of type int32.
targets = targets.int()
logit_lengths = logit_lengths.int()
target_lengths = target_lengths.int()
return torch.ops.torchaudio.rnnt_loss_alphas(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
)
def _rnnt_loss_betas(
logits,
targets,
logit_lengths,
target_lengths,
blank=-1,
clamp=-1,
):
"""
Compute betas for RNN transducer loss
See documentation for RNNTLoss
"""
targets = targets.to(device=logits.device)
logit_lengths = logit_lengths.to(device=logits.device)
target_lengths = target_lengths.to(device=logits.device)
# make sure all int tensors are of type int32.
targets = targets.int()
logit_lengths = logit_lengths.int()
target_lengths = target_lengths.int()
return torch.ops.torchaudio.rnnt_loss_betas(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
)
class _RNNT(torch.autograd.Function):
@staticmethod
def forward(
ctx,
logits,
targets,
logit_lengths,
target_lengths,
blank=-1,
clamp=-1,
runtime_check=False,
fused_log_softmax=True,
reuse_logits_for_grads=True,
):
"""
See documentation for RNNTLoss
"""
# move everything to the same device.
targets = targets.to(device=logits.device)
logit_lengths = logit_lengths.to(device=logits.device)
target_lengths = target_lengths.to(device=logits.device)
# make sure all int tensors are of type int32.
targets = targets.int()
logit_lengths = logit_lengths.int()
target_lengths = target_lengths.int()
if blank < 0: # reinterpret blank index if blank < 0.
blank = logits.shape[-1] + blank
if runtime_check:
check_inputs(
logits=logits,
targets=targets,
logit_lengths=logit_lengths,
target_lengths=target_lengths,
blank=blank,
)
costs, gradients = torch.ops.torchaudio.rnnt_loss(
logits=logits,
targets=targets,
src_lengths=logit_lengths,
tgt_lengths=target_lengths,
blank=blank,
clamp=clamp,
fused_log_smax=fused_log_softmax,
reuse_logits_for_grads=reuse_logits_for_grads,
)
ctx.grads = gradients
return costs
@staticmethod
def backward(ctx, output_gradients):
output_gradients = output_gradients.view(-1, 1, 1, 1).to(ctx.grads)
ctx.grads.mul_(output_gradients)
return (
ctx.grads, # logits
None, # targets
None, # logit_lengths
None, # target_lengths
None, # blank
None, # clamp
None, # runtime_check
None, # fused_log_softmax
None, # reuse_logits_for_grads
)
def rnnt_loss(
logits,
targets,
logit_lengths,
target_lengths,
blank=-1,
clamp=-1,
runtime_check=False,
fused_log_softmax=True,
reuse_logits_for_grads=True,
):
"""
Compute the RNN Transducer Loss.
The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
dependencies.
Args:
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
"""
if not fused_log_softmax:
logits = torch.nn.functional.log_softmax(logits, dim=-1)
reuse_logits_for_grads = False  # softmax needs the original logits value
cost = _RNNT.apply(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
runtime_check,
fused_log_softmax,
reuse_logits_for_grads,
)
return cost
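# A minimal usage sketch (illustrative shapes; with the default blank=-1 the
# blank index resolves to the last class). Shown doctest-style, not executed:
# >>> logits = torch.rand(1, 3, 3, 5, requires_grad=True)    # (B, T, U+1, D)
# >>> targets = torch.tensor([[1, 2]], dtype=torch.int32)    # (B, U)
# >>> logit_lengths = torch.tensor([3], dtype=torch.int32)   # (B,)
# >>> target_lengths = torch.tensor([2], dtype=torch.int32)  # (B,)
# >>> costs = rnnt_loss(logits, targets, logit_lengths, target_lengths)
# >>> costs.sum().backward()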
class RNNTLoss(torch.nn.Module):
"""
Compute the RNN Transducer Loss.
The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
dependencies.
Args:
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
"""
def __init__(
self,
blank=-1,
clamp=-1,
runtime_check=False,
fused_log_softmax=True,
reuse_logits_for_grads=True,
):
super().__init__()
self.blank = blank
self.clamp = clamp
self.runtime_check = runtime_check
self.fused_log_softmax = fused_log_softmax
self.reuse_logits_for_grads = reuse_logits_for_grads
def forward(
self,
logits,
targets,
logit_lengths,
target_lengths,
):
"""
Args:
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
"""
return rnnt_loss(
logits,
targets,
logit_lengths,
target_lengths,
self.blank,
self.clamp,
self.runtime_check,
self.fused_log_softmax,
self.reuse_logits_for_grads,
)
def check_type(var, t, name):
if var.dtype is not t:
raise TypeError("{} must be {}".format(name, t))
def check_contiguous(var, name):
if not var.is_contiguous():
raise ValueError("{} must be contiguous".format(name))
def check_dim(var, dim, name):
if len(var.shape) != dim:
raise ValueError("{} must be {}D".format(name, dim))
def check_equal(var1, name1, var2, name2):
if var1 != var2:
raise ValueError(
"`{}` ({}) must equal to ".format(name1, var1)
+ "`{}` ({})".format(name2, var2)
)
def check_device(var1, name1, var2, name2):
if var1.device != var2.device:
raise ValueError(
"`{}` ({}) must be on the same ".format(name1, var1.device.type)
+ "device as `{}` ({})".format(name2, var2.device.type)
)
def check_inputs(logits, targets, logit_lengths, target_lengths, blank):
check_device(logits, "logits", targets, "targets")
check_device(logits, "logits", targets, "logit_lengths")
check_device(logits, "logits", targets, "target_lengths")
check_type(logits, torch.float32, "logits")
check_type(targets, torch.int32, "targets")
check_type(logit_lengths, torch.int32, "logit_lengths")
check_type(target_lengths, torch.int32, "target_lengths")
check_contiguous(logits, "logits")
check_contiguous(targets, "targets")
check_contiguous(target_lengths, "target_lengths")
check_contiguous(logit_lengths, "logit_lengths")
check_dim(logits, 4, "logits")
check_dim(targets, 2, "targets")
check_dim(logit_lengths, 1, "logit_lengths")
check_dim(target_lengths, 1, "target_lengths")
check_equal(
logit_lengths.shape[0], "logit_lengths.shape[0]", logits.shape[0], "logits.shape[0]"
)
check_equal(
target_lengths.shape[0], "target_lengths.shape[0]", logits.shape[0], "logits.shape[0]"
)
check_equal(
targets.shape[0], "targets.shape[0]", logits.shape[0], "logits.shape[0]"
)
check_equal(
targets.shape[1],
"targets.shape[1]",
torch.max(target_lengths),
"torch.max(target_lengths)",
)
check_equal(
logits.shape[1],
"logits.shape[1]",
torch.max(logit_lengths),
"torch.max(logit_lengths)",
)
check_equal(
logits.shape[2],
"logits.shape[2]",
torch.max(target_lengths) + 1,
"torch.max(target_lengths) + 1",
)
if blank < 0 or blank >= logits.shape[-1]:
raise ValueError(
"blank ({}) must be within [0, logits.shape[-1]={})".format(
blank, logits.shape[-1]
)
)
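# Illustrative: with the shapes from the sketch above, check_inputs rejects an
# out-of-range blank index (doctest-style, not executed):
# >>> logits = torch.rand(1, 3, 3, 5)
# >>> targets = torch.tensor([[1, 2]], dtype=torch.int32)
# >>> logit_lengths = torch.tensor([3], dtype=torch.int32)
# >>> target_lengths = torch.tensor([2], dtype=torch.int32)
# >>> check_inputs(logits, targets, logit_lengths, target_lengths, blank=5)
# ValueError: blank (5) must be within [0, logits.shape[-1]=5)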
import torch
from torch.autograd import Function
from torch.nn import Module
from torchaudio._internal import (
module_utils as _mod_utils,
)
__all__ = [
"rnnt_loss",
"RNNTLoss",
]
class _RNNT(Function):
@staticmethod
def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction):
"""
See documentation for RNNTLoss.
"""
device = acts.device
acts = acts.to("cpu")
labels = labels.to("cpu")
act_lens = act_lens.to("cpu")
label_lens = label_lens.to("cpu")
loss_func = torch.ops.torchaudio.rnnt_loss
grads = torch.zeros_like(acts)
minibatch_size = acts.size(0)
costs = torch.zeros(minibatch_size, dtype=acts.dtype)
loss_func(acts, labels, act_lens, label_lens, costs, grads, blank, 0)
if reduction in ["sum", "mean"]:
costs = costs.sum().unsqueeze_(-1)
if reduction == "mean":
costs /= minibatch_size
grads /= minibatch_size
costs = costs.to(device)
ctx.grads = grads.to(device)
return costs
@staticmethod
def backward(ctx, grad_output):
grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
return ctx.grads.mul_(grad_output), None, None, None, None, None
@_mod_utils.requires_module("torchaudio._torchaudio")
def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
"""Compute the RNN Transducer Loss.
The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
dependencies.
The implementation uses `warp-transducer <https://github.com/HawkAaron/warp-transducer>`__.
Args:
acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network
before applying ``torch.nn.functional.log_softmax``.
labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
label_lens (Tensor): Tensor of dimension (batch) containing the length of each label sequence
blank (int): blank label. (Default: ``0``)
reduction (string): If ``'sum'``, the output losses will be summed.
If ``'mean'``, the summed losses will be divided by the batch size,
i.e. the mean over the batch is taken. If ``'none'``, no reduction will be applied.
(Default: ``'mean'``)
"""
# NOTE: log_softmax is applied manually for the CPU version;
# the GPU version computes log_softmax internally.
acts = torch.nn.functional.log_softmax(acts, -1)
return _RNNT.apply(acts, labels, act_lens, label_lens, blank, reduction)
@_mod_utils.requires_module("torchaudio._torchaudio")
class RNNTLoss(Module):
"""Compute the RNN Transducer Loss.
The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
dependencies.
The implementation uses `warp-transducer <https://github.com/HawkAaron/warp-transducer>`__.
Args:
blank (int): blank label. (Default: ``0``)
reduction (string): If ``'sum'``, the output losses will be summed.
If ``'mean'``, the summed losses will be divided by the batch size,
i.e. the mean over the batch is taken. If ``'none'``, no reduction will be applied.
(Default: ``'mean'``)
"""
def __init__(self, blank=0, reduction="mean"):
super(RNNTLoss, self).__init__()
self.blank = blank
self.reduction = reduction
self.loss = _RNNT.apply
def forward(self, acts, labels, act_lens, label_lens):
"""
Args:
acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network
before applying ``torch.nn.functional.log_softmax``.
labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
label_lens (Tensor): Tensor of dimension (batch) containing the length of each label sequence
"""
# NOTE: log_softmax is applied manually for the CPU version;
# the GPU version computes log_softmax internally.
acts = torch.nn.functional.log_softmax(acts, -1)
return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction)