Unverified commit 5417e4fb authored by Caroline Chen, committed by GitHub

Add GPU RNNT Loss (#1483)

parent 7d45851d
...@@ -59,6 +59,11 @@ option(BUILD_KALDI "Build kaldi statically" ON)
option(BUILD_TRANSDUCER "Enable transducer" OFF)
option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON)
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
option(USE_CUDA "Enable CUDA support" OFF)
if(USE_CUDA)
enable_language(CUDA)
endif()
find_package(Torch REQUIRED)
...
...@@ -38,6 +38,7 @@ _BUILD_SOX = False if platform.system() == 'Windows' else _get_build("BUILD_SOX"
_BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True)
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
_USE_ROCM = _get_build("USE_ROCM")
_USE_CUDA = torch.cuda.is_available()
def get_ext_modules():
...@@ -76,6 +77,7 @@ class CMakeBuild(build_ext):
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
"-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}",
f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}",
]
build_args = [
'--target', 'install'
...
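Net effect of the build changes above: when the transducer build is enabled (the BUILD_TRANSDUCER flag read via _get_build) and torch.cuda.is_available() is true at build time, setup.py passes -DUSE_CUDA:BOOL=ON to CMake, which enables the CUDA language and compiles the GPU transducer sources listed below.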
New GPU test for the RNNT loss:
import torch
from .rnnt_loss_impl import RNNTLossTest
from torchaudio_unittest import common_utils
from .utils import skipIfNoTransducer
@skipIfNoTransducer
@common_utils.skipIfNoCuda
class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase):
device = torch.device('cuda')
...@@ -20,6 +20,17 @@ if(BUILD_TRANSDUCER)
rnnt/compute_betas.cpp
rnnt/compute.cpp
)
if (USE_CUDA)
set(
CUDA_TRANSDUCER_SOURCES
rnnt/gpu/compute_alphas.cu
rnnt/gpu/compute_betas.cu
rnnt/gpu/compute.cu
)
list(APPEND TRANSDUCER_SOURCES ${CUDA_TRANSDUCER_SOURCES})
endif()
list(APPEND LIBTORCHAUDIO_SOURCES ${TRANSDUCER_SOURCES})
endif()
...@@ -105,6 +116,10 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
target_compile_definitions(_torchaudio PRIVATE INCLUDE_KALDI)
endif()
if (USE_CUDA)
target_compile_definitions(_torchaudio PRIVATE USE_CUDA)
endif()
target_include_directories(
_torchaudio
PRIVATE
...
torchaudio/csrc/rnnt/gpu/compute.cu
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace gpu {
// Entry point into RNNT Loss
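// Expected shapes (cf. the Compute documentation in gpu_transducer.h below):
// logits (B, max_T, max_U, D), targets (B, max_U - 1),
// src_lengths (B,), tgt_lengths (B,).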
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
int64_t blank,
double clamp,
bool fused_log_smax = true,
bool reuse_logits_for_grads = true) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_smax;
CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::cuda::getCurrentCUDAStream();
cudaSetDevice(logits.get_device());
options.device_ = GPU;
torch::Tensor costs = torch::empty(
options.batchSize_ * options.nHypos_,
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
c10::optional<torch::Tensor> gradients = c10::nullopt;
if (logits.requires_grad()) {
if (reuse_logits_for_grads) {
gradients = logits;
} else {
gradients = torch::zeros_like(logits);
}
}
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
switch (logits.scalar_type()) {
case torch::ScalarType::Float: {
Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*gradients=*/
(gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
break;
}
case torch::ScalarType::Half: {
Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<c10::Half>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<c10::Half>(),
/*gradients=*/
(gradients == c10::nullopt) ? nullptr
: gradients->data_ptr<c10::Half>());
break;
}
default: {
break;
}
};
return std::make_tuple(costs, gradients);
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss", &compute);
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
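For orientation, here is a minimal Python sketch of how the CUDA binding registered above could be exercised once the extension is built with USE_CUDA=ON. It assumes the torchaudio::rnnt_loss schema (defined alongside the CPU binding, not shown in this diff) takes these arguments positionally in the same order as the C++ signature, and it uses hypothetical sizes; float32 logits and int32 targets/lengths mirror the data_ptr accesses above, and clamp <= 0 disables clamping.

import torch

# Hypothetical sizes: batch 2, max_T 10, max_U 5 (target length 4), 20 symbols.
logits = torch.randn(2, 10, 5, 20, device="cuda", requires_grad=True)
targets = torch.randint(1, 20, (2, 4), dtype=torch.int32, device="cuda")
src_lengths = torch.full((2,), 10, dtype=torch.int32, device="cuda")
tgt_lengths = torch.full((2,), 4, dtype=torch.int32, device="cuda")

# blank=0, clamp=-1.0 (disabled), fused_log_smax=True, reuse_logits_for_grads=False
costs, gradients = torch.ops.torchaudio.rnnt_loss(
    logits, targets, src_lengths, tgt_lengths, 0, -1.0, True, False)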
torchaudio/csrc/rnnt/gpu/compute_alphas.cu
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace gpu {
torch::Tensor compute_alphas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::cuda::getCurrentCUDAStream();
cudaSetDevice(logits.get_device());
options.device_ = GPU;
torch::Tensor alphas = torch::zeros(
{options.batchSize_ * options.nHypos_,
options.maxSrcLen_,
options.maxTgtLen_},
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
// Only float is supported; this is mainly to enable easy unit testing.
ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*alphas=*/alphas.data_ptr<float>());
return alphas;
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss_alphas", &compute_alphas);
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/gpu/compute_betas.cu
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace gpu {
torch::Tensor compute_betas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::cuda::getCurrentCUDAStream();
cudaSetDevice(logits.get_device());
options.device_ = GPU;
torch::Tensor costs = torch::empty(
tgt_lengths.size(0),
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor betas = torch::zeros(
{options.batchSize_ * options.nHypos_,
options.maxSrcLen_,
options.maxTgtLen_},
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
// Only float is supported; this is mainly to enable easy unit testing.
ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*betas=*/betas.data_ptr<float>());
return betas;
}
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss_betas", &compute_betas);
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh
#pragma once
#ifdef USE_CUDA
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
namespace torchaudio {
namespace rnnt {
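// The two kernels below compute a numerically stable log-sum-exp over the last
// dimension in two passes: ReduceMax2D first writes each row's maximum into
// `outputs`, then ReduceLogSumExpGivenMax2D reads that maximum back and
// overwrites it with max + log(sum(exp(input - max))).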
template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceMax2D(
int dim,
const DTYPE* inputs, // [N, dim]
CAST_DTYPE* outputs) {
__shared__ CAST_DTYPE shared[NUM_THREADS];
// each block reduces one matrix row; its threads stride over the columns
int offset = blockIdx.x * dim; // [n, 0]
CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0)
for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
CAST_DTYPE next = inputs[offset + d];
if (next > val) {
val = next;
}
}
shared[threadIdx.x] = val;
__syncthreads();
for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
if (shared[threadIdx.x + stride] > shared[threadIdx.x]) {
shared[threadIdx.x] = shared[threadIdx.x + stride];
val = shared[threadIdx.x];
}
}
__syncthreads();
}
CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
if (shf > val) {
val = shf;
}
}
}
if (threadIdx.x == 0) {
outputs[blockIdx.x] = val;
}
}
template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceLogSumExpGivenMax2D(
int dim,
const DTYPE* inputs, // [N, dim]
CAST_DTYPE* outputs) { // in: max -> out: logsum
__shared__ CAST_DTYPE shared[NUM_THREADS];
CAST_DTYPE max = outputs[blockIdx.x];
CAST_DTYPE val = 0;
int offset = blockIdx.x * dim;
for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
val = val + std::exp(CAST_DTYPE(inputs[offset + d]) - max);
}
shared[threadIdx.x] = val;
__syncthreads();
for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
val = shared[threadIdx.x] + shared[threadIdx.x + stride];
shared[threadIdx.x] = val;
}
__syncthreads();
}
CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
val = val + shf;
}
}
if (threadIdx.x == 0) {
outputs[blockIdx.x] = max + std::log(val);
}
}
} // namespace rnnt
} // namespace torchaudio
#endif // USE_CUDA
torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh
#pragma once
#ifdef USE_CUDA
#include <cassert>
#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/kernels.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
namespace torchaudio {
namespace rnnt {
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeLogProbs(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
CAST_DTYPE* logProbs,
int H = 1,
bool fusedLogSmax = true) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int& D = numTargets;
const int bTgt = blockIdx.z; // 0 <= b < B
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
const int t = blockIdx.x * blockDim.x + threadIdx.x;
const int u = blockIdx.y;
if (t >= T || u >= U) { // out of boundary.
return;
}
Indexer3D indexer(maxT, maxU);
int idx = indexer(bTgt, t, u);
// skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u).
logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];
if (!fusedLogSmax) {
logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
CAST_DTYPE(logits[idx * D + blank]);
}
if (u < U - 1) {
// emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
// u).
int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
if (!fusedLogSmax) {
logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
CAST_DTYPE(logits[idx * D + target]);
}
}
}
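// ComputeAlphas fills the alpha lattice in wavefront order: each block owns a
// warp-sized slice of the t axis for a fixed u, and spins on the per-(b, u)
// atomic counters until the slices it depends on (the previous t slice and the
// previous u row) have been published.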
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeAlphas(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* alpha_counters,
volatile CAST_DTYPE* alphas,
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int bTgt = blockIdx.z; // 0 <= b < B
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
const int t = blockIdx.x * blockDim.x + threadIdx.x + 1;
const int u = blockIdx.y + 1;
if (t >= T || u >= U) { // out of boundary.
return;
}
int* counter = alpha_counters + Indexer2D(maxU)(bTgt, blockIdx.y);
Indexer3D idxr(maxT, maxU);
if (t == 1 && u == 1) {
alphas[idxr(bTgt, 0, 0)] = 0;
}
if (blockIdx.x > 0) { // wait until the previous warp (along the t-axis) is ready.
while (atomicAdd(counter, 0) < blockIdx.x) {
}
}
if (blockIdx.y > 0) { // wait until the previous warp (along the u-axis) is ready.
while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
}
}
if (t == 1 && u < U) {
// alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit().
alphas[idxr(bTgt, 0, u)] = alphas[idxr(bTgt, 0, u - 1)] +
logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX];
}
if (blockIdx.y == 0 && t < T) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE val;
#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
val = __shfl_up_sync(0xffffffff, skip_prob, i);
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
}
val = alphas[idxr(bTgt, blockIdx.x * blockDim.x, 0)];
alphas[idxr(bTgt, t, 0)] = skip_prob + val;
}
if (t < T && u < U) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE emit_prob =
logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX];
CAST_DTYPE skip =
alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob;
CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob;
CAST_DTYPE val = math::lse(skip, emit);
CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) {
val = __shfl_up_sync(0xffffffff, val, 1);
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
}
}
alphas[idxr(bTgt, t, u)] = out;
}
if (threadIdx.x == 0) {
__threadfence();
atomicAdd(counter, 1);
}
}
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeBetasCosts(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* betaCounters,
volatile CAST_DTYPE* betas,
DTYPE* costs,
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int bTgt = blockIdx.z; // 0 <= b < B
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x;
const int u = U - 2 - blockIdx.y;
if (t < 0 || u < 0) { // out of boundary.
return;
}
int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y);
Indexer3D idxr(maxT, maxU);
if (t == T - 2 && u == U - 2) {
betas[idxr(bTgt, T - 1, U - 1)] =
logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
}
if (blockIdx.x > 0) { // wait until the previous warp (along the t-axis) is ready.
while (atomicAdd(counter, 0) < blockIdx.x) {
}
}
if (blockIdx.y > 0) { // wait until the previous warp (along the u-axis) is ready.
while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
}
}
if (t == T - 2 && u >= 0) {
betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] +
logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX];
}
if (blockIdx.y == 0 && t >= 0) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE val;
#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
val = __shfl_up_sync(0xffffffff, skip_prob, i);
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
}
betas[idxr(bTgt, t, U - 1)] =
betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob;
}
if (t >= 0 && u >= 0) {
CAST_DTYPE skip_prob =
logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX];
CAST_DTYPE emit_prob =
logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX];
CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob;
CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob;
CAST_DTYPE val = math::lse(skip, emit);
CAST_DTYPE out = val;
for (int i = 1; i < warpSize; ++i) {
val = __shfl_up_sync(0xffffffff, val, 1);
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
}
}
betas[idxr(bTgt, t, u)] = out;
if (t == 0 && u == 0) { // use -beta(0, 0) as cost.
costs[bTgt] = DTYPE(-out);
}
}
if (threadIdx.x == 0) {
__threadfence();
atomicAdd(counter, 1);
}
}
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasBetasCosts(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* alpha_counters,
volatile CAST_DTYPE* alphas,
int* betaCounters,
volatile CAST_DTYPE* betas,
DTYPE* costs,
int warpSize = 0,
int numWarps = 0,
int H = 1) {
assert(threadIdx.y == 0 || threadIdx.y == 1);
if (threadIdx.y == 0) {
ComputeAlphas<DTYPE, CAST_DTYPE>(
/*maxSrcLen=*/maxSrcLen,
/*maxTgtLen=*/maxTgtLen,
/*numTargets=*/numTargets,
/*blank=*/blank,
/*logProbs=*/logProbs,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/alpha_counters,
/*alphas=*/alphas,
H);
} else { // threadIdx.y == 1
ComputeBetasCosts<DTYPE, CAST_DTYPE>(
/*maxSrcLen=*/maxSrcLen,
/*maxTgtLen=*/maxTgtLen,
/*numTargets=*/numTargets,
/*blank=*/blank,
/*logProbs=*/logProbs,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*betaCounters=*/betaCounters,
/*beta=*/betas,
/*costs=*/costs,
H);
}
}
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeGradients(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
CAST_DTYPE clamp,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients,
int H = 1,
bool fusedLogSmax = true) {
const int bTgt = blockIdx.z; // 0 <= b < B
const int t = blockIdx.x * blockDim.x + threadIdx.x;
const int u = blockIdx.y;
ComputeGradientsElement(
bTgt,
t,
u,
maxSrcLen,
maxTgtLen,
numTargets,
blank,
clamp,
logits,
targets,
srcLengths,
tgtLengths,
denominators,
alphas,
betas,
gradients,
H,
fusedLogSmax);
}
// This is a __global__ wrapper around ComputeAlphas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasWrapper(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* alpha_counters,
volatile CAST_DTYPE* alphas,
int H = 1) {
ComputeAlphas<DTYPE, CAST_DTYPE>(
maxSrcLen,
maxTgtLen,
numTargets,
blank,
logProbs,
srcLengths,
tgtLengths,
alpha_counters,
alphas,
H);
}
// This is a __global__ wrapper around ComputeBetas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeBetasWrapper(
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
int* betaCounters,
volatile CAST_DTYPE* betas,
DTYPE* costs,
int H = 1) {
ComputeBetasCosts<DTYPE, CAST_DTYPE>(
maxSrcLen,
maxTgtLen,
numTargets,
blank,
logProbs,
srcLengths,
tgtLengths,
betaCounters,
betas,
costs,
H);
}
// #undef LOG_PROBS_SKIP_IDX
// #undef LOG_PROBS_EMIT_IDX
} // namespace rnnt
} // namespace torchaudio
#endif // USE_CUDA
torchaudio/csrc/rnnt/gpu/gpu_transducer.h
#pragma once
#ifdef USE_CUDA
#include <torchaudio/csrc/rnnt/workspace.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh>
namespace torchaudio {
namespace rnnt {
namespace gpu {
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(
cudaError_t code,
const char* file,
int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(
stderr,
"\nGPUassert: %s %s %d\n",
cudaGetErrorString(code),
file,
line);
if (abort)
exit(code);
}
}
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(
cudaStream_t stream,
int N,
int D,
const DTYPE* logits, // [N, D]
CAST_DTYPE* outputs) {
{ // compute max among D.
dim3 block_dims(N);
dim3 thread_dims(REDUCE_THREADS);
ReduceMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*dim=*/D,
/*inputs=*/logits,
/*outputs=*/outputs);
// BUGBUG: These error codes are only accurate when launching with
// blocking. Otherwise they usually reflect earlier errors.
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED;
}
}
{ // compute log(sum(exp(d_i - max)))
dim3 block_dims(N);
dim3 thread_dims(REDUCE_THREADS);
ReduceLogSumExpGivenMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*dim=*/D,
/*inputs=*/logits,
/*outputs=*/outputs);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED;
}
}
return SUCCESS;
}
// Inputs:
// workspace: workspace.
// logits: pointer to (B, max_T, max_U, D) logits.
// targets: pointer to (B, max_U - 1) targets in the batch.
// srcLengths: pointer to (B, ) source lengths in the batch.
// tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
// costs: pointer to (B, ) costs in the batch.
// gradients: pointer to (B, max_T, max_U, D) gradients in the batch.
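// The computation below proceeds in up to four kernel launches: (1) LogSumExp2D
// computes per-(b, t, u) denominators, (2) ComputeLogProbs turns logits into
// (skip, emit) log-probability pairs, (3) ComputeAlphasBetasCosts runs the
// forward and backward lattice passes and writes costs, and (4)
// ComputeGradients fills the gradient buffer when one is provided.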
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* gradients = nullptr) {
const Options& options = workspace.GetOptions();
const cudaStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
const CAST_DTYPE clamp = options.clamp_;
const bool& fusedLogSmax = options.fusedLogSmax_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H,
fusedLogSmax);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute alphas, betas and costs.
// a warp is a group of 32 threads
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by a 3-d tuple;
// we launch num_warps * max_U * B * H blocks,
// where num_warps is the number of warp-sized segments along the time axis
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2-d tuple;
// the 2nd dimension has size 2: one set of lanes for alphas, one for betas
dim3 thread_dims(WARP_SIZE, 2);
ComputeAlphasBetasCosts<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
/*alphas=*/workspace.GetPointerToAlphas(),
/*beta_counters=*/workspace.GetPointerToBetaCounters(),
/*betas=*/workspace.GetPointerToBetas(),
/*costs=*/costs,
/*warp_size=*/WARP_SIZE,
/*num_warps=*/num_warps,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
if (gradients != nullptr) { // compute gradients.
// don't zero gradients here, since gradients may reuse the memory of logits
int num_blocks =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_blocks, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeGradients<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*clamp=*/clamp,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*alphas=*/workspace.GetPointerToAlphas(),
/*betas=*/workspace.GetPointerToBetas(),
/*gradients=*/gradients,
H,
fusedLogSmax);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_GRADIENTS_FAILED;
}
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* alphas) {
const Options& options = workspace.GetOptions();
const cudaStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute alphas
// a warp is a group of 32 threads
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by a 3-d tuple;
// we launch num_warps * max_U * B * H blocks,
// where num_warps is the number of warp-sized segments along the time axis
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2-d tuple;
// the 2nd dimension is 1: alphas only
dim3 thread_dims(WARP_SIZE, 1);
ComputeAlphasWrapper<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
/*alphas=*/(volatile DTYPE*)alphas,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* betas) {
const Options& options = workspace.GetOptions();
const cudaStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute betas
// a warp is a group of 32 threads
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by a 3-d tuple;
// we launch num_warps * max_U * B * H blocks,
// where num_warps is the number of warp-sized segments along the time axis
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2-d tuple;
// the 2nd dimension is 1: betas only
dim3 thread_dims(WARP_SIZE, 1);
ComputeBetasWrapper<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*beta_counters=*/workspace.GetPointerToBetaCounters(),
/*betas=*/(volatile DTYPE*)betas,
/*costs=*/costs,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
return SUCCESS;
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
#endif // USE_CUDA
torchaudio/csrc/rnnt/gpu/half.cuh
#pragma once
#ifdef USE_C10_HALF
#include "c10/util/Half.h"
#endif // USE_C10_HALF
#include <torchaudio/csrc/rnnt/macros.h>
namespace torchaudio {
namespace rnnt {
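// Minimal half-precision wrapper: float -> half conversion rounds to nearest,
// but falls back to round-toward-zero when the rounded result overflows to
// infinity, so large finite floats saturate at the largest finite half value.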
struct alignas(sizeof(__half)) Half {
__half x;
HOST_AND_DEVICE Half() = default;
FORCE_INLINE HOST_AND_DEVICE Half(float f) {
x = __float2half_rn(f);
if (isinf(__half2float(x))) {
x = __float2half_rz(f); // round toward 0.
}
}
FORCE_INLINE HOST_AND_DEVICE operator float() const {
return __half2float(x);
}
FORCE_INLINE HOST_AND_DEVICE Half(__half f) {
x = f;
}
FORCE_INLINE HOST_AND_DEVICE operator __half() const {
return x;
}
};
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/gpu/kernel_utils.h
#pragma once
#include <cassert>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
namespace torchaudio {
namespace rnnt {
inline HOST_AND_DEVICE bool in_range(
int start,
int end, // inclusive
int val) {
return start <= val && val <= end;
}
#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1
struct Indexer2D {
const int& size2_;
FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {}
FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) {
return index1 * size2_ + index2;
}
};
struct Indexer3D {
const int& size2_;
const int& size3_;
FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3)
: size2_(size2), size3_(size3) {}
FORCE_INLINE HOST_AND_DEVICE int operator()(
int index1,
int index2,
int index3) {
return (index1 * size2_ + index2) * size3_ + index3;
}
};
struct Indexer4D {
const int& size2_;
const int& size3_;
const int& size4_;
HOST_AND_DEVICE Indexer4D(
const int& size2,
const int& size3,
const int& size4)
: size2_(size2), size3_(size3), size4_(size4) {}
HOST_AND_DEVICE int operator()(
int index1,
int index2,
int index3,
int index4) {
return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4;
}
};
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/gpu/kernels.h
#pragma once
#include <cassert>
#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
namespace torchaudio {
namespace rnnt {
template <typename DTYPE, typename CAST_DTYPE>
HOST_AND_DEVICE void ComputeGradientsElement(
int bTgt,
int t,
int u,
int maxSrcLen,
int maxTgtLen,
int numTargets,
int blank,
CAST_DTYPE clamp,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients,
int H = 1,
bool fusedLogSmax = true) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int& D = numTargets;
const int bSrc = bTgt / H;
const int T = srcLengths[bSrc];
const int U = tgtLengths[bTgt] + 1;
if (t >= T || u >= U) { // out of boundary.
if (gradients == logits && t < maxT && u < maxU) {
// gradients and logits are pointing to the same memory location
Indexer3D idxr3(maxT, maxU);
int idx_b_t_u_zero = idxr3(bTgt, t, u);
if (idx_b_t_u_zero != -1) {
int start = idx_b_t_u_zero * D;
for (int b_t_u_d = start; b_t_u_d < start + D; ++b_t_u_d) {
gradients[b_t_u_d] = 0;
}
}
}
return;
}
int costIdx = bTgt * maxT * maxU;
CAST_DTYPE cost = -(betas[costIdx]);
Indexer2D idxr2(maxU - 1);
int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u;
Indexer3D idxr3(maxT, maxU);
idx_b_t_u = idxr3(bTgt, t, u);
idx_b_t_up1 = idxr3(bTgt, t, u + 1);
idx_b_tp1_u = idxr3(bTgt, t + 1, u);
if (idx_b_t_u == -1) {
return;
}
if (isinf(cost) || isnan(cost)) {
for (int d = 0; d < D; ++d) {
int b_t_u_d = idx_b_t_u * D + d;
gradients[b_t_u_d] = 0;
}
return;
}
CAST_DTYPE c = alphas[idx_b_t_u] + cost - denominators[idx_b_t_u];
for (int d = 0; d < D; ++d) {
int b_t_u_d = idx_b_t_u * D + d;
CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;
if (fusedLogSmax) {
if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
} else if (t < T - 1 && d == blank) {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
if (idx_b_tp1_u != -1) {
gradients[b_t_u_d] =
gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
}
} else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
if (idx_b_t_up1 != -1) {
gradients[b_t_u_d] =
gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
}
} else {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
}
} else { // Non fused log softmax case
CAST_DTYPE g = cost + CAST_DTYPE(logits[b_t_u_d]);
if (d == blank && t == T - 1 && u == U - 1) {
gradients[b_t_u_d] = g + alphas[idx_b_t_u];
} else if (t < T - 1 && d == blank) {
if (idx_b_tp1_u != -1) {
gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_tp1_u];
} else {
gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
}
} else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
if (idx_b_t_up1 != -1) {
gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_t_up1];
} else {
gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
}
} else {
gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
}
gradients[b_t_u_d] = -std::exp(gradients[b_t_u_d]);
}
if (clamp > 0) {
auto g = CAST_DTYPE(gradients[b_t_u_d]);
gradients[b_t_u_d] = math::max(math::min(g, clamp), -clamp);
}
}
}
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/gpu/math.cuh
#pragma once
#ifdef USE_CUDA
#include <cmath>
#endif // USE_CUDA
#include <torchaudio/csrc/rnnt/gpu/half.cuh>
namespace torchaudio {
namespace rnnt {
namespace math {
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
if (x > y)
return x;
else
return y;
}
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
if (x > y)
return y;
else
return x;
}
// log_sum_exp
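// lse(x, y) = log(exp(x) + exp(y)), evaluated as max(x, y) + log1p(exp(min - max))
// so the exponential never overflows.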
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);
template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
if (y > x) {
return y + log1pf(expf(x - y));
} else {
return x + log1pf(expf(y - x));
}
}
} // namespace math
} // namespace rnnt
} // namespace torchaudio