Commit 9dcc7a15 authored by flyingdown

init v0.10.0

parent db2b0b79
#include <torchaudio/csrc/pybind/sox/utils.h>
namespace torchaudio::sox_utils {
auto read_fileobj(py::object* fileobj, const uint64_t size, char* buffer)
-> uint64_t {
uint64_t num_read = 0;
while (num_read < size) {
auto request = size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
if (chunk_len > request) {
std::ostringstream message;
message
<< "Requested up to " << request << " bytes but, "
<< "received " << chunk_len << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
}
memcpy(buffer, chunk.data(), chunk_len);
buffer += chunk_len;
num_read += chunk_len;
}
return num_read;
}
} // namespace torchaudio::sox_utils
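A minimal usage sketch for the helper above (hypothetical caller, not part of this file): it assumes `fileobj` wraps a Python file-like object exposing `read(n)` and that the GIL is held while it runs.

#include <array>
#include <torchaudio/csrc/pybind/sox/utils.h>

// Hypothetical example: read up to 4096 bytes from a Python file-like object
// into a stack buffer. read_fileobj returns the number of bytes actually read,
// which may be smaller than requested if the object reaches EOF first.
uint64_t read_header_example(py::object& fileobj) {
  std::array<char, 4096> buffer{};
  return torchaudio::sox_utils::read_fileobj(
      &fileobj, buffer.size(), buffer.data());
}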
#ifndef TORCHAUDIO_PYBIND_SOX_UTILS_H
#define TORCHAUDIO_PYBIND_SOX_UTILS_H
#include <torch/extension.h>
namespace torchaudio::sox_utils {
auto read_fileobj(py::object* fileobj, uint64_t size, char* buffer) -> uint64_t;
} // namespace torchaudio::sox_utils
#endif
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>
namespace torchaudio {
namespace rnnt {
class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
public:
static torch::autograd::tensor_list forward(
torch::autograd::AutogradContext* ctx,
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
torch::Tensor undef;
auto result =
rnnt_loss(logits, targets, logit_lengths, target_lengths, blank, clamp);
auto costs = std::get<0>(result);
auto grads = std::get<1>(result).value_or(undef);
ctx->save_for_backward({grads});
return {costs, grads};
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto grad = saved[0];
auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
auto result = grad * grad_out;
torch::Tensor undef;
// One gradient per forward input: logits gets `result`, while targets,
// logit_lengths, target_lengths, blank, and clamp get undefined tensors.
return {result, undef, undef, undef, undef, undef};
}
};
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
at::AutoDispatchBelowADInplaceOrView guard;
auto results = RNNTLossFunction::apply(
logits, targets, logit_lengths, target_lengths, blank, clamp);
return std::make_tuple(results[0], results[1]);
}
TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
m.impl("rnnt_loss", rnnt_loss_autograd);
}
} // namespace rnnt
} // namespace torchaudio
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("torchaudio::rnnt_loss", "")
.typed<decltype(rnnt_loss)>();
return op.call(logits, targets, logit_lengths, target_lengths, blank, clamp);
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss(Tensor logits,"
"Tensor targets,"
"Tensor logit_lengths,"
"Tensor target_lengths,"
"int blank,"
"float clamp) -> (Tensor, Tensor?)");
}
#pragma once
#include <torch/script.h>
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp);
#include <torch/script.h>
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss_alphas(Tensor logits,"
"Tensor targets,"
"Tensor logit_lengths,"
"Tensor target_lengths,"
"int blank,"
"float clamp) -> Tensor");
}
#include <torch/script.h>
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss_betas(Tensor logits,"
"Tensor targets,"
"Tensor logit_lengths,"
"Tensor target_lengths,"
"int blank,"
"float clamp) -> Tensor");
}
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace cpu {
// Entry point into RNNT Loss
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
TORCH_CHECK(
logits.device().type() == logit_lengths.device().type(),
"logits and logit_lengths must be on the same device");
TORCH_CHECK(
logits.device().type() == target_lengths.device().type(),
"logits and target_lengths must be on the same device");
TORCH_CHECK(
logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
"logits must be float32 or float16 (half) type");
TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
TORCH_CHECK(
logit_lengths.dtype() == torch::kInt32,
"logit_lengths must be int32 type");
TORCH_CHECK(
target_lengths.dtype() == torch::kInt32,
"target_lengths must be int32 type");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(
logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
TORCH_CHECK(
target_lengths.is_contiguous(), "target_lengths must be contiguous");
TORCH_CHECK(
logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch, max target length)");
TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
TORCH_CHECK(
logit_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and logit_lengths");
TORCH_CHECK(
target_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and target_lengths");
TORCH_CHECK(
targets.size(0) == logits.size(0),
"batch dimension mismatch between logits and targets");
TORCH_CHECK(
blank >= 0 && blank < logits.size(-1),
"blank must be within [0, logits.shape[-1])");
TORCH_CHECK(
logits.size(1) == at::max(logit_lengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
logits.size(2) == at::max(target_lengths).item().toInt() + 1,
"output length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(target_lengths).item().toInt(),
"target length mismatch");
Options options;
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
options.device_ = CPU;
torch::Tensor costs = torch::empty(
options.batchSize_ * options.nHypos_,
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
switch (logits.scalar_type()) {
case torch::ScalarType::Float: {
Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*gradients=*/gradients->data_ptr<float>());
break;
}
case torch::ScalarType::Half: {
Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<c10::Half>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<c10::Half>(),
/*gradients=*/gradients->data_ptr<c10::Half>());
break;
}
default: {
break;
}
};
return std::make_tuple(costs, gradients);
}
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss", &compute);
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
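A hypothetical smoke test for the CPU entry point above (not part of the commit); it uses the smallest shapes that satisfy the TORCH_CHECKs in compute(): logits of shape (B, T, U, D), int32 targets of shape (B, U - 1), and int32 length vectors of size B.

#include <iostream>
#include <torch/torch.h>

// Assumes the declaration of torchaudio::rnnt::cpu::compute is visible and the
// library is linked; tensor values are arbitrary.
int main() {
  const int64_t B = 1, T = 2, U = 3, D = 5;
  torch::Tensor logits = torch::randn({B, T, U, D}); // float32, contiguous
  torch::Tensor targets =
      torch::tensor({1, 2}, torch::kInt32).reshape({B, U - 1});
  torch::Tensor logit_lengths = torch::tensor({2}, torch::kInt32);  // == T
  torch::Tensor target_lengths = torch::tensor({2}, torch::kInt32); // == U - 1
  auto result = torchaudio::rnnt::cpu::compute(
      logits, targets, logit_lengths, target_lengths,
      /*blank=*/0, /*clamp=*/-1.0);
  std::cout << "costs: " << std::get<0>(result) << std::endl;
  return 0;
}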
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace cpu {
torch::Tensor compute_alphas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
options.device_ = CPU;
torch::Tensor alphas = torch::zeros(
{options.batchSize_ * options.nHypos_,
options.maxSrcLen_,
options.maxTgtLen_},
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
// Only the float dtype is supported; this is mainly to enable easy
// unit testing.
ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*alphas=*/alphas.data_ptr<float>());
return alphas;
}
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss_alphas", &compute_alphas);
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace cpu {
torch::Tensor compute_betas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
options.device_ = CPU;
torch::Tensor costs = torch::empty(
target_lengths.size(0),
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor betas = torch::zeros(
{options.batchSize_ * options.nHypos_,
options.maxSrcLen_,
options.maxTgtLen_},
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
// Only the float dtype is supported; this is mainly to enable easy
// unit testing.
ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*betas=*/betas.data_ptr<float>());
return betas;
}
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss_betas", &compute_betas);
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/rnnt/cpu/math.h>
#include <torchaudio/csrc/rnnt/options.h>
#include <torchaudio/csrc/rnnt/types.h>
#include <cstring>
#include <limits>
#include <vector>
namespace torchaudio {
namespace rnnt {
namespace cpu {
template <typename DTYPE>
struct LogProbs {
DTYPE skip_; // blank.
DTYPE emit_; // target.
LogProbs(DTYPE skip, DTYPE emit) : skip_(skip), emit_(emit) {}
DTYPE& skip() {
return skip_;
}
DTYPE& emit() {
return emit_;
}
const DTYPE& skip() const {
return skip_;
}
const DTYPE& emit() const {
return emit_;
}
};
// TensorView: view a block of allocated memory as a tensor.
template <typename DTYPE>
class TensorView {
public:
TensorView(const std::vector<int>& dims, DTYPE* data)
: dims_(dims), data_(data) {
strides_.resize(dims.size());
strides_.back() = 1;
for (int i = dims.size() - 2; i >= 0; --i) {
strides_[i] = strides_[i + 1] * dims[i + 1];
}
}
DTYPE& operator()(const std::vector<int>& indices) {
CHECK_EQ(indices.size(), dims_.size());
int index = indices.back();
for (int i = indices.size() - 2; i >= 0; --i) {
index += indices[i] * strides_[i];
}
return data_[index];
}
void SetZero() {
int size = dims_[0] * strides_[0];
std::memset(data_, 0, sizeof(DTYPE) * size);
}
private:
std::vector<int> dims_;
std::vector<int> strides_;
DTYPE* data_;
};
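// A small illustration (not used by the library) of how TensorView maps
// multi-dimensional indices onto a flat buffer with the row-major strides
// computed in the constructor above; values are illustrative only.
// strides_ == {3, 1}, so (i, j) maps to data[i * 3 + j].
inline float TensorViewExample() {
  float data[6] = {0, 1, 2, 3, 4, 5};
  TensorView<float> view({2, 3}, data);
  float x = view({1, 2}); // reads data[1 * 3 + 2] == 5
  view.SetZero();         // zeroes all 2 * 3 elements
  return x;
}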
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(int N, int D, const DTYPE* logits, CAST_DTYPE* outputs) {
for (int i = 0; i < N * D; i += D) {
CAST_DTYPE max = logits[i];
for (int j = 1; j < D; ++j) {
max = std::max(max, CAST_DTYPE(logits[i + j]));
}
CAST_DTYPE sum = 0;
for (int j = 0; j < D; ++j) {
sum = sum + std::exp(CAST_DTYPE(logits[i + j]) - max);
}
outputs[i / D] = max + std::log(sum);
}
return SUCCESS;
}
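// For reference, LogSumExp2D above evaluates, per row of the (N, D) input, the
// numerically stable form of the log of the softmax denominator:
//
//   denom_i = log(sum_j exp(x_ij)) = m_i + log(sum_j exp(x_ij - m_i)),
//   where m_i = max_j x_ij,
//
// so every exponent is at most zero and cannot overflow.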
template <typename DTYPE, typename CAST_DTYPE>
void ComputeLogProbsOneSequence(
const Options& options,
TensorView<const DTYPE>& logits,
const int* targets,
int srcLen,
int tgtLen,
TensorView<const CAST_DTYPE>& denom,
TensorView<LogProbs<CAST_DTYPE>>& logProbs) {
const int& T = srcLen;
const int& U = tgtLen;
const int& blank = options.blank_;
for (int t = 0; t < T; ++t) {
for (int u = 0; u < U; ++u) {
if (u < U - 1) {
logProbs({t, u}).emit() =
CAST_DTYPE(logits({t, u, targets[u]})) - denom({t, u});
}
logProbs({t, u}).skip() =
CAST_DTYPE(logits({t, u, blank})) - denom({t, u});
}
}
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeLogProbs(
const Options& options,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
CAST_DTYPE* logProbs) {
std::vector<TensorView<const DTYPE>> seqLogits;
std::vector<const int*> seqTargets;
std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
std::vector<TensorView<LogProbs<CAST_DTYPE>>> seqlogProbs;
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
const int& D = options.numTargets_;
for (int b = 0; b < B; ++b) {
seqLogits.push_back(
TensorView<const DTYPE>({maxT, maxU, D}, logits + b * maxT * maxU * D));
seqTargets.push_back(targets + b * (maxU - 1));
seqDenoms.push_back(TensorView<const CAST_DTYPE>(
{maxT, maxU}, denominators + b * maxT * maxU));
seqlogProbs.push_back(TensorView<LogProbs<CAST_DTYPE>>(
{maxT, maxU},
reinterpret_cast<LogProbs<CAST_DTYPE>*>(logProbs) + b * maxT * maxU));
}
//#pragma omp parallel for
for (int b = 0; b < B; ++b) { // use max 2 * B threads.
ComputeLogProbsOneSequence<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/seqLogits[b],
/*targets=*/seqTargets[b],
/*srcLen=*/srcLengths[b],
/*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
/*denom=*/seqDenoms[b],
/*logProbs=*/seqlogProbs[b]);
}
return SUCCESS;
}
template <typename DTYPE>
DTYPE ComputeAlphaOneSequence(
const Options& options,
TensorView<const LogProbs<DTYPE>>& logProbs,
int srcLen,
int tgtLen,
TensorView<DTYPE>& alpha) {
const int& T = srcLen;
const int& U = tgtLen;
alpha({0, 0}) = DTYPE(0);
for (int t = 1; t < T; ++t) { // u == 0.
alpha({t, 0}) = alpha({t - 1, 0}) + logProbs({t - 1, 0}).skip();
}
for (int u = 1; u < U; ++u) { // t == 0.
alpha({0, u}) = alpha({0, u - 1}) + logProbs({0, u - 1}).emit();
}
for (int t = 1; t < T; ++t) {
for (int u = 1; u < U; ++u) {
alpha({t, u}) = math::lse(
alpha({t - 1, u}) + logProbs({t - 1, u}).skip(),
alpha({t, u - 1}) + logProbs({t, u - 1}).emit());
}
}
DTYPE forward_score = alpha({T - 1, U - 1}) + logProbs({T - 1, U - 1}).skip();
return forward_score;
}
template <typename DTYPE>
DTYPE ComputeBetaOneSequence(
const Options& options,
TensorView<const LogProbs<DTYPE>>& logProbs,
int srcLen,
int tgtLen,
TensorView<DTYPE>& beta) {
const int& T = srcLen;
const int& U = tgtLen;
beta({T - 1, U - 1}) = logProbs({T - 1, U - 1}).skip();
for (int t = T - 2; t >= 0; --t) { // u == U - 1.
beta({t, U - 1}) = beta({t + 1, U - 1}) + logProbs({t, U - 1}).skip();
}
for (int u = U - 2; u >= 0; --u) { // t == T - 1.
beta({T - 1, u}) = beta({T - 1, u + 1}) + logProbs({T - 1, u}).emit();
}
for (int t = T - 2; t >= 0; --t) {
for (int u = U - 2; u >= 0; --u) {
beta({t, u}) = math::lse(
beta({t + 1, u}) + logProbs({t, u}).skip(),
beta({t, u + 1}) + logProbs({t, u}).emit());
}
}
DTYPE backward_score = beta({0, 0});
return backward_score;
}
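// For reference, the two routines above implement the standard RNN-T
// forward/backward recursions in log space, where skip(t, u) is the blank
// log-probability and emit(t, u) the target log-probability at node (t, u):
//
//   alpha(0, 0) = 0
//   alpha(t, u) = lse(alpha(t-1, u) + skip(t-1, u),
//                     alpha(t, u-1) + emit(t, u-1))
//   beta(T-1, U-1) = skip(T-1, U-1)
//   beta(t, u) = lse(beta(t+1, u) + skip(t, u),
//                    beta(t, u+1) + emit(t, u))
//
// with the boundary rows/columns reducing to a single term, as coded above.
// The total log-likelihood is alpha(T-1, U-1) + skip(T-1, U-1) == beta(0, 0),
// so the per-sequence cost is -beta(0, 0).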
template <typename DTYPE>
DTYPE ComputeAlphaOrBetaOneSequence(
int thread,
const Options& options,
TensorView<const LogProbs<DTYPE>>& logProbs,
int srcLen,
int tgtLen,
TensorView<DTYPE>& alpha,
TensorView<DTYPE>& beta) {
if (thread & 1) {
return ComputeAlphaOneSequence<DTYPE>(
/*options=*/options,
/*logProbs=*/logProbs,
/*srcLen=*/srcLen,
/*tgtLen=*/tgtLen,
/*alpha=*/alpha);
} else {
return ComputeBetaOneSequence<DTYPE>(
/*options=*/options,
/*logProbs=*/logProbs,
/*srcLen=*/srcLen,
/*tgtLen=*/tgtLen,
/*beta=*/beta);
}
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeAlphasBetas(
const Options& options,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
CAST_DTYPE* alphas,
CAST_DTYPE* betas,
DTYPE* costs) {
std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
std::vector<TensorView<CAST_DTYPE>> seq_alphas;
std::vector<TensorView<CAST_DTYPE>> seq_betas;
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
for (int b = 0; b < B; ++b) {
seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
{maxT, maxU},
reinterpret_cast<LogProbs<CAST_DTYPE>*>(
const_cast<CAST_DTYPE*>(logProbs)) +
b * maxT * maxU));
seq_alphas.push_back(
TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
seq_betas.push_back(
TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
}
std::vector<CAST_DTYPE> scores(B << 1);
//#pragma omp parallel for
for (int t = 0; t < (B << 1); ++t) { // use max 2 * B threads.
int i = (t >> 1);
scores[t] = ComputeAlphaOrBetaOneSequence<CAST_DTYPE>(
/*thread=*/t,
/*options=*/options,
/*logProbs=*/seqlogProbs[i],
/*srcLen=*/srcLengths[i],
/*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
/*alpha=*/seq_alphas[i],
/*beta=*/seq_betas[i]);
}
for (int b = 0; b < B; ++b) {
costs[b] = -scores[b << 1];
}
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeGradientsOneSequence(
const Options& options,
TensorView<const DTYPE>& logits,
const int* targets,
int srcLen,
int tgtLen,
TensorView<const CAST_DTYPE>& denom,
TensorView<const CAST_DTYPE>& alpha,
TensorView<const CAST_DTYPE>& beta,
TensorView<DTYPE>& gradients) {
// Don't set gradients to zero here, as gradients might reuse memory from
// logits.
const int& T = srcLen;
const int& U = tgtLen;
const int& D = options.numTargets_;
const int& blank = options.blank_;
const CAST_DTYPE clamp = options.clamp_;
CAST_DTYPE cost = -beta({0, 0});
// Note: the gradient below differs from numpy_transducer, since we compute
// log_softmax more efficiently within the loss to save memory. The details of
// the implementation / equations below can be found in Sec. 3.2 (function
// merging) of the following paper:
// https://www.microsoft.com/en-us/research/uploads/prod/2019/10/RNNT.pdf
for (int t = 0; t < T; ++t) {
for (int u = 0; u < U; ++u) {
CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u});
for (int d = 0; d < D; ++d) {
CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c;
if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g);
} else if (d == blank && t < T - 1) {
gradients({t, u, d}) =
std::exp(g + beta({t, u})) - std::exp(g + beta({t + 1, u}));
} else if (u < U - 1 && d == targets[u]) {
gradients({t, u, d}) =
std::exp(g + beta({t, u})) - std::exp(g + beta({t, u + 1}));
} else {
gradients({t, u, d}) = std::exp(g + beta({t, u}));
}
if (clamp > 0) {
gradients({t, u, d}) =
math::min(CAST_DTYPE(gradients({t, u, d})), clamp);
gradients({t, u, d}) =
math::max(CAST_DTYPE(gradients({t, u, d})), -clamp);
}
}
}
}
// Zero out the rest of the gradients; this is necessary when gradients reuse
// the logits memory. Compare the memory locations to see whether it applies.
if (&gradients({0, 0, 0}) == &logits({0, 0, 0})) {
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
for (int t = T; t < maxT; ++t) {
for (int u = 0; u < maxU; ++u) {
for (int d = 0; d < D; ++d) {
gradients({t, u, d}) = 0.;
}
}
}
for (int t = 0; t < T; ++t) {
for (int u = U; u < maxU; ++u) {
for (int d = 0; d < D; ++d) {
gradients({t, u, d}) = 0.;
}
}
}
}
}
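// Written out, the update in ComputeGradientsOneSequence is, with
// g = logits(t, u, d) + alpha(t, u) - beta(0, 0) - denom(t, u):
//
//   d loss / d logits(t, u, d) =
//     exp(g + beta(t, u)) - exp(g)                 if d == blank, t == T-1, u == U-1
//     exp(g + beta(t, u)) - exp(g + beta(t+1, u))  if d == blank, t < T-1
//     exp(g + beta(t, u)) - exp(g + beta(t, u+1))  if d == targets[u], u < U-1
//     exp(g + beta(t, u))                          otherwise
//
// optionally clamped to [-clamp, clamp] when clamp > 0.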
template <typename DTYPE, typename CAST_DTYPE>
void ComputeGradients(
const Options& options,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients) {
std::vector<TensorView<const DTYPE>> seqLogits;
std::vector<const int*> seqTargets;
std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
std::vector<TensorView<const CAST_DTYPE>> seq_alphas;
std::vector<TensorView<const CAST_DTYPE>> seq_betas;
std::vector<TensorView<DTYPE>> seq_gradients;
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
const int& D = options.numTargets_;
for (int b = 0; b < B; ++b) {
seqLogits.push_back(
TensorView<const DTYPE>({maxT, maxU, D}, logits + b * maxT * maxU * D));
seqTargets.push_back(targets + b * (maxU - 1));
seqDenoms.push_back(TensorView<const CAST_DTYPE>(
{maxT, maxU}, denominators + b * maxT * maxU));
seq_alphas.push_back(
TensorView<const CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
seq_betas.push_back(
TensorView<const CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
seq_gradients.push_back(
TensorView<DTYPE>({maxT, maxU, D}, gradients + b * maxT * maxU * D));
}
//#pragma omp parallel for
for (int b = 0; b < B; ++b) { // use max 2 * B threads.
ComputeGradientsOneSequence<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/seqLogits[b],
/*targets=*/seqTargets[b],
/*srcLen=*/srcLengths[b],
/*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
/*denom=*/seqDenoms[b],
/*alpha=*/seq_alphas[b],
/*beta=*/seq_betas[b],
/*gradients=*/seq_gradients[b]);
}
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeAlphas(
const Options& options,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
CAST_DTYPE* alphas) {
std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
std::vector<TensorView<CAST_DTYPE>> seq_alphas;
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
for (int b = 0; b < B; ++b) {
seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
{maxT, maxU},
reinterpret_cast<LogProbs<CAST_DTYPE>*>(
const_cast<CAST_DTYPE*>(logProbs)) +
b * maxT * maxU));
seq_alphas.push_back(
TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
}
//#pragma omp parallel for
for (int i = 0; i < B; ++i) { // use max B threads.
ComputeAlphaOneSequence<CAST_DTYPE>(
/*options=*/options,
/*logProbs=*/seqlogProbs[i],
/*srcLen=*/srcLengths[i],
/*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
/*alpha=*/seq_alphas[i]);
}
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeBetas(
const Options& options,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
CAST_DTYPE* costs,
CAST_DTYPE* betas) {
std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
std::vector<TensorView<CAST_DTYPE>> seq_betas;
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
for (int b = 0; b < B; ++b) {
seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
{maxT, maxU},
reinterpret_cast<LogProbs<CAST_DTYPE>*>(
const_cast<CAST_DTYPE*>(logProbs)) +
b * maxT * maxU));
seq_betas.push_back(
TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
}
//#pragma omp parallel for
for (int i = 0; i < B; ++i) { // use max B threads.
ComputeBetaOneSequence<CAST_DTYPE>(
/*options=*/options,
/*logProbs=*/seqlogProbs[i],
/*srcLen=*/srcLengths[i],
/*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
/*beta=*/seq_betas[i]);
}
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/rnnt/cpu/cpu_kernels.h>
#include <torchaudio/csrc/rnnt/workspace.h>
namespace torchaudio {
namespace rnnt {
namespace cpu {
// Inputs:
// workspace: workspace.
// logits: pointer to (B, maxT, maxU, D) logits.
// targets: pointer to (B, maxU - 1) targets in the batch.
// srcLengths: pointer to (B, ) source lengths in the batch.
// tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
// costs: pointer to (B, ) costs in the batch.
// gradients: pointer to (B, maxT, maxU, D) gradients in the batch.
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* gradients = nullptr) {
const Options& options = workspace.GetOptions();
CHECK_EQ(options.device_, CPU);
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
const int& D = options.numTargets_;
{ // compute denominators.
LogSumExp2D<DTYPE, CAST_DTYPE>(
/*N=*/B * maxT * maxU,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
}
{ // compute log prob pairs.
ComputeLogProbs<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs());
}
{ // compute alphas and betas.
ComputeAlphasBetas<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alphas=*/workspace.GetPointerToAlphas(),
/*betas=*/workspace.GetPointerToBetas(),
/*costs=*/costs);
}
if (gradients != nullptr) {
ComputeGradients<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*alphas=*/workspace.GetPointerToAlphas(),
/*betas=*/workspace.GetPointerToBetas(),
/*gradients=*/gradients);
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* alphas) {
const Options& options = workspace.GetOptions();
CHECK_EQ(options.device_, CPU);
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
const int& D = options.numTargets_;
{ // compute denominators.
LogSumExp2D<DTYPE, CAST_DTYPE>(
/*N=*/B * maxT * maxU,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
}
{ // compute log prob pairs.
ComputeLogProbs<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs());
}
{ // compute alphas.
ComputeAlphas<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alphas=*/alphas);
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* betas) {
const Options& options = workspace.GetOptions();
CHECK_EQ(options.device_, CPU);
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
const int& D = options.numTargets_;
{ // compute denominators.
LogSumExp2D<DTYPE, CAST_DTYPE>(
/*N=*/B * maxT * maxU,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
}
{ // compute log prob pairs.
ComputeLogProbs<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs());
}
{ // compute betas.
ComputeBetas<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/costs,
/*betas=*/betas);
}
return SUCCESS;
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <cassert>
#include <torchaudio/csrc/rnnt/cpu/math.h>
namespace torchaudio {
namespace rnnt {
inline HOST_AND_DEVICE bool in_range(
int start,
int end, // inclusive
int val) {
return start <= val && val <= end;
}
#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1
struct Indexer2D {
const int& size2_;
FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {}
FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) {
return index1 * size2_ + index2;
}
};
struct Indexer3D {
const int& size2_;
const int& size3_;
FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3)
: size2_(size2), size3_(size3) {}
FORCE_INLINE HOST_AND_DEVICE int operator()(
int index1,
int index2,
int index3) {
return (index1 * size2_ + index2) * size3_ + index3;
}
};
struct Indexer4D {
const int& size2_;
const int& size3_;
const int& size4_;
HOST_AND_DEVICE Indexer4D(
const int& size2,
const int& size3,
const int& size4)
: size2_(size2), size3_(size3), size4_(size4) {}
HOST_AND_DEVICE int operator()(
int index1,
int index2,
int index3,
int index4) {
return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4;
}
};
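// A small, illustrative example (not used by the library) of the row-major
// flattening these indexers perform, mirroring how the kernels index
// (b, t, u) with Indexer3D(maxT, maxU):
inline int IndexerExample() {
  const int maxT = 4, maxU = 10;
  Indexer3D idx(maxT, maxU);
  return idx(2, 3, 7); // (2 * 4 + 3) * 10 + 7 == 117
}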
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/rnnt/macros.h>
namespace torchaudio {
namespace rnnt {
namespace math {
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
if (x > y)
return x;
else
return y;
}
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
if (x > y)
return y;
else
return x;
}
// log_sum_exp
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);
template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
if (y > x) {
return y + log1pf(expf(x - y));
} else {
return x + log1pf(expf(y - x));
}
}
} // namespace math
} // namespace rnnt
} // namespace torchaudio
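The float specialization of lse above is the standard numerically stable pairwise log-sum-exp,

\operatorname{lse}(x, y) = \log\bigl(e^{x} + e^{y}\bigr) = \max(x, y) + \operatorname{log1p}\bigl(e^{-\lvert x - y\rvert}\bigr),

which always exponentiates a non-positive argument and therefore cannot overflow.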
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace gpu {
// Entry point into RNNT Loss
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
TORCH_CHECK(
logits.device().type() == logit_lengths.device().type(),
"logits and logit_lengths must be on the same device");
TORCH_CHECK(
logits.device().type() == target_lengths.device().type(),
"logits and target_lengths must be on the same device");
TORCH_CHECK(
logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
"logits must be float32 or float16 (half) type");
TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
TORCH_CHECK(
logit_lengths.dtype() == torch::kInt32,
"logit_lengths must be int32 type");
TORCH_CHECK(
target_lengths.dtype() == torch::kInt32,
"target_lengths must be int32 type");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(
logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
TORCH_CHECK(
target_lengths.is_contiguous(), "target_lengths must be contiguous");
TORCH_CHECK(
logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch, max target length)");
TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
TORCH_CHECK(
logit_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and logit_lengths");
TORCH_CHECK(
target_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and target_lengths");
TORCH_CHECK(
targets.size(0) == logits.size(0),
"batch dimension mismatch between logits and targets");
TORCH_CHECK(
blank >= 0 && blank < logits.size(-1),
"blank must be within [0, logits.shape[-1])");
TORCH_CHECK(
logits.size(1) == at::max(logit_lengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
logits.size(2) == at::max(target_lengths).item().toInt() + 1,
"output length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(target_lengths).item().toInt(),
"target length mismatch");
Options options;
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::cuda::getCurrentCUDAStream();
cudaSetDevice(logits.get_device());
options.device_ = GPU;
torch::Tensor costs = torch::empty(
options.batchSize_ * options.nHypos_,
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
switch (logits.scalar_type()) {
case torch::ScalarType::Float: {
Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*gradients=*/gradients->data_ptr<float>());
break;
}
case torch::ScalarType::Half: {
Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<c10::Half>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<c10::Half>(),
/*gradients=*/gradients->data_ptr<c10::Half>());
break;
}
default: {
break;
}
};
return std::make_tuple(costs, gradients);
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss", &compute);
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace gpu {
torch::Tensor compute_alphas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::cuda::getCurrentCUDAStream();
cudaSetDevice(logits.get_device());
options.device_ = GPU;
torch::Tensor alphas = torch::zeros(
{options.batchSize_ * options.nHypos_,
options.maxSrcLen_,
options.maxTgtLen_},
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
// Only the float dtype is supported; this is mainly to enable easy
// unit testing.
ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*alphas=*/alphas.data_ptr<float>());
return alphas;
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss_alphas", &compute_alphas);
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace gpu {
torch::Tensor compute_betas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::cuda::getCurrentCUDAStream();
cudaSetDevice(logits.get_device());
options.device_ = GPU;
torch::Tensor costs = torch::empty(
target_lengths.size(0),
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor betas = torch::zeros(
{options.batchSize_ * options.nHypos_,
options.maxSrcLen_,
options.maxTgtLen_},
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
// Only the float dtype is supported; this is mainly to enable easy
// unit testing.
ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*betas=*/betas.data_ptr<float>());
return betas;
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss_betas", &compute_betas);
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
#pragma once
#ifdef USE_CUDA
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
namespace torchaudio {
namespace rnnt {
template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceMax2D(
int dim,
const DTYPE* inputs, // [N, dim]
CAST_DTYPE* outputs) {
__shared__ CAST_DTYPE shared[NUM_THREADS];
// each thread reduces one matrix row
int offset = blockIdx.x * dim; // [n, 0]
CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0)
for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
CAST_DTYPE next = inputs[offset + d];
if (next > val) {
val = next;
}
}
shared[threadIdx.x] = val;
__syncthreads();
for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
if (shared[threadIdx.x + stride] > shared[threadIdx.x]) {
shared[threadIdx.x] = shared[threadIdx.x + stride];
val = shared[threadIdx.x];
}
}
__syncthreads();
}
CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
if (shf > val) {
val = shf;
}
}
}
if (threadIdx.x == 0) {
outputs[blockIdx.x] = val;
}
}
template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceLogSumExpGivenMax2D(
int dim,
const DTYPE* inputs, // [N, dim]
CAST_DTYPE* outputs) { // in: max -> out: logsum
__shared__ CAST_DTYPE shared[NUM_THREADS];
CAST_DTYPE max = outputs[blockIdx.x];
CAST_DTYPE val = 0;
int offset = blockIdx.x * dim;
for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
val = val + std::exp(CAST_DTYPE(inputs[offset + d]) - max);
}
shared[threadIdx.x] = val;
__syncthreads();
for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
val = shared[threadIdx.x] + shared[threadIdx.x + stride];
shared[threadIdx.x] = val;
}
__syncthreads();
}
CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
val = val + shf;
}
}
if (threadIdx.x == 0) {
outputs[blockIdx.x] = max + std::log(val);
}
}
} // namespace rnnt
} // namespace torchaudio
#endif // USE_CUDA
#pragma once
#ifdef USE_CUDA
#include <cassert>
#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/kernels.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
namespace torchaudio {
namespace rnnt {
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeLogProbs(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
CAST_DTYPE* logProbs,
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int& D = numTargets;
const int bTgt = blockIdx.z; // 0 <= b < B
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
const int t = blockIdx.x * blockDim.x + threadIdx.x;
const int u = blockIdx.y;
if (t >= T || u >= U) { // out of bounds.
return;
}
Indexer3D indexer(maxT, maxU);
int idx = indexer(bTgt, t, u);
// skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u).
logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];
if (u < U - 1) {
// emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
// u).
int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
}
}
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeAlphas(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* alpha_counters,
volatile CAST_DTYPE* alphas,
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int bTgt = blockIdx.z; // 0 <= b < B
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
const int t = blockIdx.x * blockDim.x + threadIdx.x + 1;
const int u = blockIdx.y + 1;
if (t >= T || u >= U) { // out of bounds.
return;
}
int* counter = alpha_counters + Indexer2D(maxU)(bTgt, blockIdx.y);
Indexer3D idxr(maxT, maxU);
if (t == 1 && u == 1) {
alphas[idxr(bTgt, 0, 0)] = 0;
}
if (blockIdx.x > 0) { // wait until the previous warp (in the t-axis) is ready.
while (atomicAdd(counter, 0) < blockIdx.x) {
}
}
if (blockIdx.y > 0) { // wait until the previous warp (in the u-axis) is ready.
while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
}
}
if (t == 1 && u < U) {
// alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit().
alphas[idxr(bTgt, 0, u)] = alphas[idxr(bTgt, 0, u - 1)] +
logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX];
}
if (blockIdx.y == 0 && t < T) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE val;
#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
val = __shfl_up_sync(0xffffffff, skip_prob, i);
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
}
val = alphas[idxr(bTgt, blockIdx.x * blockDim.x, 0)];
alphas[idxr(bTgt, t, 0)] = skip_prob + val;
}
if (t < T && u < U) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE emit_prob =
logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX];
CAST_DTYPE skip =
alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob;
CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob;
CAST_DTYPE val = math::lse(skip, emit);
CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) {
val = __shfl_up_sync(0xffffffff, val, 1);
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
}
}
alphas[idxr(bTgt, t, u)] = out;
}
if (threadIdx.x == 0) {
__threadfence();
atomicAdd(counter, 1);
}
}
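// How the kernel above is scheduled (a reading of the code, for reference):
// each block covers one warp-wide tile along t for a single u row of one
// sequence, and alpha_counters holds one counter per (sequence, u row). A
// block spins until its own row's counter reaches blockIdx.x (the previous
// tile of the row has been published) and, when blockIdx.y > 0, until the
// previous row's counter exceeds blockIdx.x (the cells at u - 1 that it reads
// are ready). Within a tile, the serial __shfl_up_sync loop propagates alpha
// lane by lane along t; thread 0 then issues a __threadfence() and increments
// the row counter so dependent blocks can proceed.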
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeBetasCosts(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* betaCounters,
volatile CAST_DTYPE* betas,
DTYPE* costs,
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int bTgt = blockIdx.z; // 0 <= b < B
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x;
const int u = U - 2 - blockIdx.y;
if (t < 0 || u < 0) { // out of bounds.
return;
}
int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y);
Indexer3D idxr(maxT, maxU);
if (t == T - 2 && u == U - 2) {
betas[idxr(bTgt, T - 1, U - 1)] =
logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
}
if (blockIdx.x > 0) { // wait until the previous warp (in the t-axis) is ready.
while (atomicAdd(counter, 0) < blockIdx.x) {
}
}
if (blockIdx.y > 0) { // wait until the previous warp (in the u-axis) is ready.
while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
}
}
if (t == T - 2 && u >= 0) {
betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] +
logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX];
}
if (blockIdx.y == 0 && t >= 0) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE val;
#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
val = __shfl_up_sync(0xffffffff, skip_prob, i);
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
}
betas[idxr(bTgt, t, U - 1)] =
betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob;
}
if (t >= 0 && u >= 0) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE emit_prob =
logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX];
CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob;
CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob;
CAST_DTYPE val = math::lse(skip, emit);
CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) {
val = __shfl_up_sync(0xffffffff, val, 1);
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
}
}
betas[idxr(bTgt, t, u)] = out;
if (t == 0 && u == 0) { // use -beta(0, 0) as cost.
costs[bTgt] = DTYPE(-out);
}
}
if (threadIdx.x == 0) {
__threadfence();
atomicAdd(counter, 1);
}
}
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasBetasCosts(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* alpha_counters,
volatile CAST_DTYPE* alphas,
int* betaCounters,
volatile CAST_DTYPE* betas,
DTYPE* costs,
int warpSize = 0,
int numWarps = 0,
int H = 1) {
assert(threadIdx.y == 0 || threadIdx.y == 1);
if (threadIdx.y == 0) {
ComputeAlphas<DTYPE, CAST_DTYPE>(
/*maxSrcLen=*/maxSrcLen,
/*maxTgtLen=*/maxTgtLen,
/*numTargets=*/numTargets,
/*blank=*/blank,
/*logProbs=*/logProbs,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/alpha_counters,
/*alphas=*/alphas,
H);
} else { // threadIdx.y == 1
ComputeBetasCosts<DTYPE, CAST_DTYPE>(
/*maxSrcLen=*/maxSrcLen,
/*maxTgtLen=*/maxTgtLen,
/*numTargets=*/numTargets,
/*blank=*/blank,
/*logProbs=*/logProbs,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*betaCounters=*/betaCounters,
/*betas=*/betas,
/*costs=*/costs,
H);
}
}
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeGradients(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
CAST_DTYPE clamp,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients,
int H = 1) {
const int bTgt = blockIdx.z; // 0 <= b < B
const int t = blockIdx.x * blockDim.x + threadIdx.x;
const int u = blockIdx.y;
ComputeGradientsElement(
bTgt,
t,
u,
maxSrcLen,
maxTgtLen,
numTargets,
blank,
clamp,
logits,
targets,
srcLengths,
tgtLengths,
denominators,
alphas,
betas,
gradients,
H);
}
// This is a __global__ wrapper around ComputeAlphas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasWrapper(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* alpha_counters,
volatile CAST_DTYPE* alphas,
int H = 1) {
ComputeAlphas<DTYPE, CAST_DTYPE>(
maxSrcLen,
maxTgtLen,
numTargets,
blank,
logProbs,
srcLengths,
tgtLengths,
alpha_counters,
alphas,
H);
}
// This is a __global__ wrapper around ComputeBetas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeBetasWrapper(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* betaCounters,
volatile CAST_DTYPE* betas,
DTYPE* costs,
int H = 1) {
ComputeBetasCosts<DTYPE, CAST_DTYPE>(
maxSrcLen,
maxTgtLen,
numTargets,
blank,
logProbs,
srcLengths,
tgtLengths,
betaCounters,
betas,
costs,
H);
}
// #undef LOG_PROBS_SKIP_IDX
// #undef LOG_PROBS_EMIT_IDX
} // namespace rnnt
} // namespace torchaudio
#endif // USE_CUDA
#pragma once
#ifdef USE_CUDA
#include <torchaudio/csrc/rnnt/workspace.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh>
namespace torchaudio {
namespace rnnt {
namespace gpu {
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(
cudaError_t code,
const char* file,
int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(
stderr,
"\nGPUassert: %s %s %d\n",
cudaGetErrorString(code),
file,
line);
if (abort)
exit(code);
}
}
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(
cudaStream_t stream,
int N,
int D,
const DTYPE* logits, // [N, D]
CAST_DTYPE* outputs) {
{ // compute max among D.
dim3 block_dims(N);
dim3 thread_dims(REDUCE_THREADS);
ReduceMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*dim=*/D,
/*inputs=*/logits,
/*outputs=*/outputs);
// BUGBUG: These error codes are only accurate when launching with
// blocking. Otherwise they usually reflect earlier errors.
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED;
}
}
{ // compute log(sum(exp(d_i - max)))
dim3 block_dims(N);
dim3 thread_dims(REDUCE_THREADS);
ReduceLogSumExpGivenMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*dim=*/D,
/*inputs=*/logits,
/*outputs=*/outputs);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED;
}
}
return SUCCESS;
}
// Inputs:
// workspace: workspace.
// logits: pointer to (B, max_T, max_U, D) logits.
// targets: pointer to (B, max_U - 1) targets in the batch.
// srcLengths: pointer to (B, ) source lengths in the batch.
// tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
// costs: pointer to (B, ) costs in the batch.
// gradients: pointer to (B, max_T, max_U, D) gradients in the batch.
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* gradients = nullptr) {
const Options& options = workspace.GetOptions();
const cudaStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
const CAST_DTYPE clamp = options.clamp_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute alphas, betas and costs.
// a warp is a group of threads (usually 32).
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by a 3-D tuple:
// we are using num_warps * max_U * (B * H) blocks,
// where num_warps is the division along the time axis.
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2-D tuple;
// the 2nd dim is 2: one set of threads for alpha, one for beta.
dim3 thread_dims(WARP_SIZE, 2);
ComputeAlphasBetasCosts<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
/*alphas=*/workspace.GetPointerToAlphas(),
/*beta_counters=*/workspace.GetPointerToBetaCounters(),
/*betas=*/workspace.GetPointerToBetas(),
/*costs=*/costs,
/*warp_size=*/WARP_SIZE,
/*num_warps=*/num_warps,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
if (gradients != nullptr) { // compute gradients.
// Don't set gradients to zero here, as gradients might reuse memory from
// logits.
int num_blocks =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_blocks, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeGradients<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*clamp=*/clamp,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*alphas=*/workspace.GetPointerToAlphas(),
/*betas=*/workspace.GetPointerToBetas(),
/*gradients=*/gradients,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_GRADIENTS_FAILED;
}
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* alphas) {
const Options& options = workspace.GetOptions();
const cudaStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute alphas
// a warp is a group of threads (usually 32).
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by a 3-D tuple:
// we are using num_warps * max_U * (B * H) blocks,
// where num_warps is the division along the time axis.
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2-D tuple;
// the 2nd dim is 1: alphas only.
dim3 thread_dims(WARP_SIZE, 1);
ComputeAlphasWrapper<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
/*alphas=*/(volatile DTYPE*)alphas,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* betas) {
const Options& options = workspace.GetOptions();
const cudaStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute betas
// a warp is a group of threads (usually 32).
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by a 3-D tuple:
// we are using num_warps * max_U * (B * H) blocks,
// where num_warps is the division along the time axis.
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2-D tuple;
// the 2nd dim is 1: betas only.
dim3 thread_dims(WARP_SIZE, 1);
ComputeBetasWrapper<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*betaCounters=*/workspace.GetPointerToBetaCounters(),
/*betas=*/(volatile DTYPE*)betas,
/*costs=*/costs,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
return SUCCESS;
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
#endif // USE_CUDA