"examples/vscode:/vscode.git/clone" did not exist on "071807c853d28f04f8f9b1db54cd124fd6fcda2c"
Unverified commit d86d1b09 authored by Nan Zheng, committed by GitHub

Initial check-in of the transducer extensions (#1069)

* Initial check-in of the transducer extension.

* Added more comments to help explain the code

* Corrected minor typos

* 1. Renamed variable in tests to match the extension
2. Disabled ninja build option
#include <torch/extension.h>
#include <ATen/Functions.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
int tileSize);
std::vector<torch::Tensor> transducer_joint_cuda_backward(
torch::Tensor grad,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput);
torch::Tensor transducer_joint_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
int tileSize) {
CHECK_INPUT(f);
CHECK_INPUT(g);
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_forward(
f,
g,
fLen,
gLen,
batchOffset,
packedBatch,
opt,
packOutput,
tileSize);
}
std::vector<torch::Tensor> transducer_joint_backward(
torch::Tensor grad,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput) {
CHECK_INPUT(grad);
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_backward(
grad,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
packOutput);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)");
m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)");
}
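For orientation, here is a minimal, hedged sketch of how a binding like the one above could be JIT-compiled and called from Python with torch.utils.cpp_extension.load; the source file names are assumptions for illustration, not taken from this commit (the real build goes through apex's setup.py).

import torch
from torch.utils.cpp_extension import load

# Assumed file names; illustration only.
transducer_joint_cuda = load(
    name="transducer_joint_cuda",
    sources=["transducer_joint.cpp", "transducer_joint_kernel.cu"],
    verbose=True)

B, T, U, H = 2, 5, 3, 8
f = torch.randn(B, T, H, device="cuda")
g = torch.randn(B, U, H, device="cuda")
f_len = torch.tensor([5, 4], dtype=torch.int, device="cuda")
g_len = torch.tensor([3, 2], dtype=torch.int, device="cuda")
batch_offset = torch.cumsum((f_len * g_len).long(), dim=0)

# Signature: forward(f, g, fLen, gLen, batchOffset, packedBatch, opt, packOutput, tileSize).
# With packOutput=False, batchOffset and packedBatch are ignored and h has shape (B, T, U, H).
h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, 0, 1, False, 4)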
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/macros/Macros.h>
#include <THC/THC.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
// Warp reduce function to reduce N groups of data into N numbers, where N = warpSize / width.
// width should be a power of 2 and no larger than warpSize.
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){
for (unsigned offset = width/2; offset > 0; offset /= 2){
x += __shfl_down_sync(0xffffffff, x, offset, width);
}
return x;
}
inline int largestPowerOfTwo(int x){
int y = 1;
while (y <= x)
y <<= 1;
return y >> 1;
}
// Helper class to calculate pointer offsets that can be shared by different flavors of kernels.
// For fwd, the batch offset and stride differ between packing and non-packing modes.
struct OffsetCalFwd{
__device__ __forceinline__ OffsetCalFwd(
int64_t batch,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t gLen,
int64_t hiddenSize,
bool packOutput) :
batch(batch),
batchOffset(batchOffset),
maxFLen(maxFLen),
maxGLen(maxGLen),
gLen(gLen),
hiddenSize(hiddenSize),
packOutput(packOutput)
{}
int64_t batch;
const int64_t *batchOffset;
int64_t maxFLen;
int64_t maxGLen;
int64_t gLen;
int64_t hiddenSize;
bool packOutput;
__device__ __forceinline__ int64_t getBatchOffset(){
return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize
: batch*maxFLen*maxGLen*hiddenSize;
}
__device__ __forceinline__ int64_t getStrideF(){
return packOutput ? gLen*hiddenSize : maxGLen*hiddenSize;
}
};
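To make the offset arithmetic above concrete, a small Python sketch (illustrative only, mirroring getBatchOffset/getStrideF): in non-packing mode every batch owns a dense maxFLen x maxGLen x hiddenSize slab, while in packing mode batch b starts right after the valid region of batch b-1, whose cumulative size is batchOffset[b-1].

def fwd_offsets(batch, batch_offset, max_f_len, max_g_len, g_len, hidden_size, pack_output):
    # Mirrors OffsetCalFwd::getBatchOffset() and getStrideF() in plain Python.
    if pack_output:
        start = 0 if batch == 0 else batch_offset[batch - 1]
        return start * hidden_size, g_len * hidden_size      # (batch offset, stride along t)
    return batch * max_f_len * max_g_len * hidden_size, max_g_len * hidden_size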
// Helper class to calculate pointer offsets that can be shared by different flavors of kernels.
// For bwd, the batch offset and stride differ between packing and non-packing modes.
// The reduction is done for two input tensors. Therefore, generating two sets of offsets
// according to bwdFasterDim leads to a unified implementation in the actual kernel.
struct OffsetCalBwd{
__device__ __forceinline__ OffsetCalBwd(
int64_t batch,
const int64_t *batchOffset,
const int *fLen,
const int *gLen,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim) :
batch(batch),
batchOffset(batchOffset),
maxFLen(maxFLen),
maxGLen(maxGLen),
fLen(fLen),
gLen(gLen),
hiddenSize(hiddenSize),
packOutput(packOutput),
bwdFasterDim(bwdFasterDim)
{}
int64_t batch;
const int64_t *batchOffset;
const int *fLen;
const int *gLen;
int64_t maxFLen;
int64_t maxGLen;
int64_t hiddenSize;
bool packOutput;
bool bwdFasterDim; // whether doing bwd on the faster moving dimension
__device__ __forceinline__ int64_t getBatchOffset(){
return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize
: batch*maxFLen*maxGLen*hiddenSize;
}
__device__ __forceinline__ int64_t getMaxXLen(){
return bwdFasterDim ? maxGLen : maxFLen;
}
__device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]){
return bwdFasterDim ? gLen[batch] : fLen[batch];
}
__device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]){
return bwdFasterDim ? fLen[batch] : gLen[batch];
}
__device__ __forceinline__ int64_t getStrideX(){
return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize);
}
__device__ __forceinline__ int64_t getStrideY(){
return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize;
}
};
// Vanilla transducer joint forward kernel
// Detail of this joint function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// f is a tensor of shape [batch, T, H]
// g is a tensor of shape [batch, U, H]
// the transducer joint does
// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
// The resultant tensor is of shape [batch, T, U, H]
// Each thread block is working on one "batch" of data in the output tensor, [batch, t, u, :]
// This joint function can optionally pack the output, in which case the output tensor of shape
// [B, T, U, H] is packed into [B_packed, H].
// Don't-care regions (t >= fLen or u >= gLen) are removed.
// To enable packing, the starting offset for each batch needs to be specified with batchOffset.
template <typename scalar_t, class OffsetCal>
__global__ void transducer_joint_forward(
const scalar_t *f,
const scalar_t *g,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
scalar_t *sum) {
const int batch = blockIdx.z;
const int t = blockIdx.y;
const int u = blockIdx.x;
const auto myFLen = fLen[batch];
const auto myGLen = gLen[batch];
OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideF = offsetCal.getStrideF();
scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize;
scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize;
scalar_t *mySum = sum + myBatchOffset + t*strideF + u * hiddenSize;
if (t < myFLen and u < myGLen){
#pragma unroll
for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){
if (h < hiddenSize){
mySum[h] = myF[h] + myG[h];
}
}
}
else if (packOutput == false and t < maxFLen and u < maxGLen){
// Need to write finite data to the don't-care region because we instantiate the result tensor
// with torch::empty for performance reasons. Even though it is a don't-care region, the
// contents need to be finite; otherwise they could lead to NaN in WGRAD.
// In packing mode, this write is no longer necessary, as we remove the don't-care region
// from the output.
// -1 is picked (over 0) here for ease of testing.
#pragma unroll
for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){
if (h < hiddenSize){
mySum[h] = -1;
}
}
}
}
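Functionally, the kernel above computes the broadcast sum described in the comment; a dense PyTorch sketch of the same computation, including the finite fill of the don't-care region, follows (illustrative, not the reference implementation used by the unit tests).

import torch

def joint_dense(f, g, f_len, g_len):
    # f: [B, T, H], g: [B, U, H] -> out: [B, T, U, H]
    out = f.unsqueeze(2) + g.unsqueeze(1)
    for b in range(f.size(0)):
        # emulate the kernel's finite (-1) fill of the don't-care region
        out[b, f_len[b]:] = -1
        out[b, :, g_len[b]:] = -1
    return out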
// Tiled version of the joint forward kernel
// Detail of this joint function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// f is a tensor of shape [batch, T, H]
// g is a tensor of shape [batch, U, H]
// the transducer joint does
// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
// The resultant tensor is of shape [batch, T, U, H]
// Each thread works on a tile of shape tileF x tileG in the result tensor.
// The inputs for a tile are first loaded into registers and then reused tileG and tileF times,
// respectively.
// This joint function can optionally pack the output, in which case the output tensor of shape
// [B, T, U, H] is packed into [B_packed, H].
// Don't-care regions (t >= fLen or u >= gLen) are removed.
// To enable packing, the starting offset for each batch needs to be specified with batchOffset.
template <typename scalar_t, int tileF, int tileG, class OffsetCal>
__global__ void transducer_joint_tiled_forward(
const scalar_t *f,
const scalar_t *g,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
int64_t hiddenPerBlock,
bool packOutput,
scalar_t *sum) {
const int batch = blockIdx.z;
const int t = blockIdx.y * tileF;
const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
const int u = blockIdx.x / hiddenBlock * tileG;
const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock;
const int h = threadIdx.x;
const auto myFLen = fLen[batch];
const auto myGLen = gLen[batch];
OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideF = offsetCal.getStrideF();
scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset;
scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset;
scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset;
if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){
// register buffers for tiled input reuse
scalar_t fBuffer[tileF], gBuffer[tileG];
for (int i = 0; i < tileF; ++i){
if (t + i < myFLen)
fBuffer[i] = myF[i*hiddenSize + h];
}
for (int j = 0; j < tileG; ++j){
if (u + j < myGLen)
gBuffer[j] = myG[j*hiddenSize + h];
}
#pragma unroll
for (int i = 0; i < tileF; ++i){
if (t + i < myFLen){
#pragma unroll
for (int j = 0; j < tileG; ++j){
if (u + j < myGLen)
mySum[i*strideF + j*hiddenSize + h] = fBuffer[i] + gBuffer[j];
else if (packOutput == false and u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
else if (packOutput == false and t + i < maxFLen){
// Again need to write finite data to don't-care region
#pragma unroll
for (int j = 0; j < tileG; ++j){
if (u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
}
}
else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset+h < hiddenSize){
// Only need to ensure finite values in non-packing mode
#pragma unroll
for (int i = 0; i < tileF; ++i){
if (t + i < maxFLen){
#pragma unroll
for (int j = 0; j < tileG; ++j){
if (u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
}
}
}
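The point of the tiling is data reuse: each thread loads tileF elements of f and tileG elements of g once into registers and produces all tileF*tileG sums from them. A scalar Python sketch of that reuse pattern (hypothetical helper, one (batch, h) slice):

def tile_sums(f_vals, g_vals):
    # f_vals: tileF loaded values f[b, t+i, h]; g_vals: tileG loaded values g[b, u+j, h].
    # Every loaded value is reused len(g_vals) or len(f_vals) times, respectively.
    return [[fi + gj for gj in g_vals] for fi in f_vals]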
// Bwd operation (reduction) on one input tensor. Since the operations performed for the two input
// tensors are exactly the same, only one kernel is needed, and the different indexing offsets
// and strides are handled by OffsetCalBwd.
// When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
// non-packed form.
template <typename scalar_t, typename acc_t, class OffsetCal>
__device__ void transducer_joint_single_backward(
const scalar_t *grad,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim, // whether bwd on the faster moving dimension (u)
scalar_t *inGrad,
int yBlockOffset=0) {
const int batch = blockIdx.z;
// For the second input tensor, this offset needs to be subtracted because the first yBlockOffset
// sets of thread blocks are for the first input tensor.
const int x = blockIdx.y-yBlockOffset;
const int hOffset = blockIdx.x*C10_WARP_SIZE;
const int wid = threadIdx.y;
const int lid = threadIdx.x;
const int numWarp = blockDim.y;
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput,
bwdFasterDim);
const auto maxXLen = offsetCal.getMaxXLen();
const auto myXLen = offsetCal.getMyXLen();
const auto myYLen = offsetCal.getMyYLen();
scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;
if (x < myXLen){
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideX = offsetCal.getStrideX();
const auto strideY = offsetCal.getStrideY();
scalar_t const *myGrad = grad + myBatchOffset + x*strideX + hOffset;
// Each warp reduces numYPerWarp "y" first
acc_t warpSum = 0;
auto numYPerWarp = (myYLen+numWarp-1)/numWarp;
for (int warpY = 0; warpY < numYPerWarp; ++warpY){
auto y = wid*numYPerWarp + warpY;
if (y < myYLen and (hOffset+lid) < hiddenSize)
warpSum += myGrad[y*strideY + lid];
}
// transpose partial sum in SMEM and reduce further using warpReduce
smem[lid*numWarp + wid] = warpSum;
__syncthreads();
auto sum = smem[wid*C10_WARP_SIZE + lid];
sum = warpReduce(sum, numWarp);
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// example of 4 warps (a, b, c, d) with 8 threads per warp
// Each warp needs 8 / 4 = 2 threads to write the results.
if (hOffset+wid*C10_WARP_SIZE/numWarp+lid/numWarp < hiddenSize){
if (lid % numWarp == 0){
myInGrad[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = sum;
}
}
}
else if (wid == 0 and hOffset + lid < hiddenSize){
// Need to ensure the grad is zero for don't care region
myInGrad[lid] = 0;
}
}
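In tensor terms, the per-row reduction above is a sum of the incoming gradient over the other sequence dimension, with rows beyond the valid length zeroed. A dense PyTorch sketch for the non-packed case (illustrative; it matches the kernel under the tests' assumption that grad is already zero in the don't-care region):

import torch

def joint_backward_dense(grad, f_len, g_len):
    # grad: [B, T, U, H] -> f_grad: [B, T, H], g_grad: [B, U, H]
    f_grad = grad.sum(dim=2)
    g_grad = grad.sum(dim=1)
    for b in range(grad.size(0)):
        f_grad[b, f_len[b]:] = 0
        g_grad[b, g_len[b]:] = 0
    return f_grad, g_grad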
// The actual bwd (reduction) kernel that gets launched.
// Calls transducer_joint_single_backward twice, once for each input tensor.
// The two bwd ops are launched together; the first op uses blockIdx.y < maxFLen, and the second op
// uses the rest.
template <typename scalar_t, typename acc_t, class OffsetCal>
__global__ void transducer_joint_combined_backward(
const scalar_t *grad,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
scalar_t *fGrad,
scalar_t *gGrad) {
if (blockIdx.y < maxFLen){
transducer_joint_single_backward<scalar_t, acc_t, OffsetCal>(
grad,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
false,
fGrad);
}
else{
transducer_joint_single_backward<scalar_t, acc_t, OffsetCal>(
grad,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
true,
gGrad,
maxFLen);
}
}
// Vectorized version of transducer_joint_single_backward.
// Does exactly the same operation as transducer_joint_single_backward, except that the loads and
// stores are vectorized.
// When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
// non-packed form.
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal>
__device__ void transducer_joint_single_vec_backward(
const scalar_t *grad,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim,
scalar_t *inGrad,
int yBlockOffset=0){
const int batch = blockIdx.z;
const int x = blockIdx.y - yBlockOffset;
const int hOffset = blockIdx.x*C10_WARP_SIZE*V;
const int wid = threadIdx.y;
const int lid = threadIdx.x;
const int numWarp = blockDim.y;
OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput,
bwdFasterDim);
const auto maxXLen = offsetCal.getMaxXLen();
const auto myXLen = offsetCal.getMyXLen();
const auto myYLen = offsetCal.getMyYLen();
scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
acc_t warpSum[V];
scalar_t inBuffer[V];
scalar_t outBuffer[V];
auto myInGradVec = reinterpret_cast<vec_t*>(myInGrad);
auto outBufferVec = reinterpret_cast<vec_t*>(outBuffer);
if (x < myXLen){
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideX = offsetCal.getStrideX();
const auto strideY = offsetCal.getStrideY();
const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
for (int i = 0; i < V; ++i)
warpSum[i] = 0;
// Each warp reduces numYPerWarp "y" first
auto numYPerWarp = (myYLen+numWarp-1)/numWarp;
for (int warpY = 0; warpY < numYPerWarp; ++warpY){
auto y = wid*numYPerWarp + warpY;
auto myGradVec = reinterpret_cast<vec_t const *>(myGrad + y*strideY);
auto inBufferVec = reinterpret_cast<vec_t*>(inBuffer);
if (hOffset + lid*V < hiddenSize and y < myYLen){
*inBufferVec = myGradVec[lid]; // vectorized load
#pragma unroll
for (int i = 0; i < V; ++i){
warpSum[i] += inBuffer[i];
}
}
}
// transpose partial sum in SMEM and reduce further using warpReduce
for (int i = 0; i < V; ++i){
smem[lid*numWarp + wid] = warpSum[i];
__syncthreads();
auto sum = smem[wid*C10_WARP_SIZE + lid];
if (hOffset+(wid*C10_WARP_SIZE/numWarp)*V < hiddenSize){
sum = warpReduce(sum, numWarp);
if (lid % numWarp == 0){
outBuffer[i] = sum;
}
}
__syncthreads();
}
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// example of 4 warps (a, b, c, d) with 8 threads per warp
// Each warp needs 8 / 4 = 2 threads to write the results.
if (lid % numWarp == 0 and hOffset+(wid*C10_WARP_SIZE/numWarp + lid/numWarp)*V < hiddenSize)
myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec;
}
else if (wid == 0 and hOffset + lid*V < hiddenSize){
// Need to ensure the grad is zero for don't care region
myInGradVec[lid] = 0;
}
}
// Vectorized version of transducer_joint_combined_backward.
// Calls transducer_joint_single_vec_backward twice, once for each input tensor.
// The two bwd ops are launched together; the first op uses blockIdx.y < maxFLen, and the second op
// uses the rest.
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal>
__global__ void transducer_joint_combined_vec_backward(
const scalar_t *grad,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
scalar_t *fGrad,
scalar_t *gGrad) {
if (blockIdx.y < maxFLen){
transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal>(
grad,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
false,
fGrad);
}
else{
transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal>(
grad,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
true,
gGrad,
maxFLen);
}
}
torch::Tensor transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
int tileSize){
auto tensorOpt = f.options();
auto dtype = f.scalar_type();
const auto batchSize = f.size(0);
const auto maxFLen = f.size(1);
const auto maxGLen = g.size(1);
const auto hiddenSize = f.size(2);
int64_t *batchOffsetPtr = nullptr;
torch::Tensor sum;
if (!packOutput){
sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt);
batchOffsetPtr = nullptr;
}
else{
sum = torch::empty({packedBatch, hiddenSize}, tensorOpt);
batchOffsetPtr = batchOffset.data_ptr<int64_t>();
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt);
// Simple heuristics
const int numThread = std::min(128, (static_cast<int>(hiddenSize)+C10_WARP_SIZE-1)
/ C10_WARP_SIZE * C10_WARP_SIZE);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
if (opt == 0){
// vanilla kernel
const int threads = numThread;
const dim3 blocks(maxGLen, maxFLen, batchSize);
transducer_joint_forward<scalar_t, OffsetCalFwd>
<<<blocks, threads, 0, stream>>>(
f.data_ptr<scalar_t>(),
g.data_ptr<scalar_t>(),
fLen.data_ptr<int>(),
gLen.data_ptr<int>(),
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
sum.data_ptr<scalar_t>());
}
if (opt == 1){
// tiled version. For simplicity, assume tileF == tileG, even though the kernel can
// support more general cases.
const int threads = numThread;
const int hiddenPerBlock = numThread;
const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock,
(maxFLen+tileSize-1)/tileSize,
batchSize);
TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4,
"Expected tileSize to be in [1, 2, 4], but got ", tileSize);
switch (tileSize) {
#define LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(tile) case tile:\
transducer_joint_tiled_forward<scalar_t, tile, tile, OffsetCalFwd>\
<<<blocks, threads, 0, stream>>>(\
f.data_ptr<scalar_t>(),\
g.data_ptr<scalar_t>(),\
fLen.data_ptr<int>(),\
gLen.data_ptr<int>(),\
batchOffsetPtr,\
maxFLen,\
maxGLen,\
hiddenSize,\
hiddenPerBlock,\
packOutput,\
sum.data_ptr<scalar_t>());\
break;
LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(1);
LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(2);
LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(4);
}
}
}));
THCudaCheck(cudaGetLastError());
return sum;
}
std::vector<torch::Tensor> transducer_joint_cuda_backward(
torch::Tensor grad,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput){
auto tensorOpt = grad.options();
auto dtype = grad.scalar_type();
const int batchSize = fLen.size(0);
const int hiddenSize = grad.size(-1);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE;
torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt);
torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt);
int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr<int64_t>();
// The number "y" I would like each thread to work on
const int workPerThread = 32;
// Since the bwd ops for f and g use the same thread block size, we need to use the max of the two.
int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread-1) / workPerThread);
// Would like to have at least 2 warps
numWarp = std::max(2, numWarp);
// cap on the maximum number of warps allowed
numWarp = std::min(maxNumWarp, numWarp);
// Need smem for transposing the partial sum. The partial sum is in a matrix of the shape
// numWarp x warpSize
const int smemSize = numWarp * C10_WARP_SIZE;
const dim3 threads(C10_WARP_SIZE, numWarp, 1);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] {
auto gradPtr = grad.data_ptr<scalar_t>();
auto fLenPtr = fLen.data_ptr<int>();
auto gLenPtr = gLen.data_ptr<int>();
auto fGradPtr = fGrad.data_ptr<scalar_t>();
auto gGradPtr = gGrad.data_ptr<scalar_t>();
// resolve the acc_t type
using acc_t = at::acc_type<scalar_t, true>;
using vec_t = uint64_t;
constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
constexpr int vecAlignment = std::alignment_of<vec_t>::value;
// Check whether all input and output tensors meet the alignment requirement
bool memAlign = (reinterpret_cast<uint64_t>(gradPtr) % vecAlignment == 0)
and (reinterpret_cast<uint64_t>(fGradPtr) % vecAlignment == 0)
and (reinterpret_cast<uint64_t>(gGradPtr) % vecAlignment == 0);
if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){
// If vectorization helps and the alignment requirement is met, use the vectorized
// kernel. For simplicity, hiddenSize needs to be a multiple of vectFactor.
const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor),
maxFLen+maxGLen,
batchSize);
transducer_joint_combined_vec_backward
<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
fGradPtr,
gGradPtr);
}
else{
const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE,
maxFLen + maxGLen, batchSize);
transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
fGradPtr,
gGradPtr);
}
}));
return {fGrad, gGrad};
}
#include <torch/extension.h>
#include <vector>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput);
torch::Tensor transducer_loss_cuda_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput);
std::vector<torch::Tensor> transducer_loss_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor fLen,
torch::Tensor yLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput
) {
CHECK_INPUT(x);
CHECK_INPUT(label);
CHECK_INPUT(fLen);
CHECK_INPUT(yLen);
if (packedInput)
CHECK_INPUT(batchOffset);
return transducer_loss_cuda_forward(
x,
label,
fLen,
yLen,
batchOffset,
maxFLen,
blankIdx,
opt,
packedInput);
}
torch::Tensor transducer_loss_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor fLen,
torch::Tensor yLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput){
CHECK_INPUT(x);
CHECK_INPUT(label);
CHECK_INPUT(lossGrad);
CHECK_INPUT(alpha);
CHECK_INPUT(beta);
CHECK_INPUT(fLen);
CHECK_INPUT(yLen);
if (packedInput)
CHECK_INPUT(batchOffset);
return transducer_loss_cuda_backward(
x,
lossGrad,
alpha,
beta,
fLen,
yLen,
label,
batchOffset,
maxFLen,
blankIdx,
opt,
fuseSoftmaxBackward,
packedInput);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)");
m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)");
}
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h>
template<typename scalar_t>
__device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) {
// The standard log-sum-exp trick is used here to provide better numerical stability
return (a >= b) ? a + std::log1p(exp(b-a)) : b + std::log1p(exp(a-b));
}
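A quick Python check of the max-shifted form used above (sketch): it agrees with the naive log(exp(a) + exp(b)) where the naive form is representable, and still works when both exponentials would underflow.

import math

def log_sum_exp(a, b):
    # same max-shifted form as the device function above
    return a + math.log1p(math.exp(b - a)) if a >= b else b + math.log1p(math.exp(a - b))

assert abs(log_sum_exp(1.0, 2.0) - math.log(math.exp(1.0) + math.exp(2.0))) < 1e-12
# For large negative inputs the naive form fails (both exp() calls underflow to 0.0),
# while the shifted form still returns roughly -999.6867.
print(log_sum_exp(-1000.0, -1001.0))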
// Vanilla transducer loss function (i.e. forward-backward algorithm)
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// Forward (alpha) and backward (beta) paths are launched together. The input is assumed to have
// been converted into log scale by the preceding log_softmax layer.
// Diagonal wavefront advancing, as commonly used in dynamic programming, is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.
// This loss function supports packed input, where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// Don't-care regions (t >= audLen or u >= txtLen) are removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_forward(
const scalar_t* x,
const int* label,
const int* audLen,
const int* txtLen,
const int64_t* batchOffset,
int64_t dictSize, // 64-bit indexing for data tensor
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
acc_t* alpha,
acc_t* beta,
scalar_t* loss) {
const int batch = blockIdx.y;
const int tid = threadIdx.x;
const auto myFLen = audLen[batch];
// Note that 1 is added here to account for the start-of-sequence symbol
const auto myGLen = txtLen[batch] + 1;
const auto myLabel = label + batch * (maxGLen-1);
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
const scalar_t* myX = x + myBatchOffset * dictSize;
int u = tid;
if (blockIdx.x == 0){
// alpha path
acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;
if (u == 0)
myAlpha[0] = 0;
__syncthreads();
for (int64_t step = 1; step < myFLen+myGLen-1; ++step){
// Move along the diagonal wavefront to leverage available parallelism
for (u = tid; u < myGLen; u += blockDim.x){
int64_t t = step - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(16) in [1]
if (u == 0){
// alpha(t, u) = alpha(t-1, u) * null(t-1, u)
myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen]
+ myX[((t-1)*myStrideT) * dictSize + blankIdx];
}
else if (t == 0){
// alpha(t, u) = alpha(t, u-1) * y(t, u-1)
myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]];
}
else{
// alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1)
acc_t current = myAlpha[(t-1)*maxGLen + u]
+ myX[((t-1)*myStrideT + u) * dictSize + blankIdx];
acc_t next = myAlpha[t*maxGLen + u - 1]
+ myX[(t*myStrideT + u - 1) * dictSize + myLabel[u - 1]];
myAlpha[t*maxGLen + u] = logSumExp(next, current);
}
}
}
__syncthreads();
}
}
else if (blockIdx.x == 1){
// beta path
acc_t* myBeta = beta + batch*maxFLen*maxGLen;
if (u == 0){
myBeta[(myFLen-1)*maxGLen + myGLen - 1] = myX[((myFLen-1)*myStrideT
+ myGLen - 1) * dictSize + blankIdx];
}
__syncthreads();
for (int64_t step = myFLen+myGLen - 3; step >= 0; --step){
for (u = tid; u < myGLen; u += blockDim.x){
int64_t t = step - u;
if (t >= 0 and t < myFLen and u >=0 and u < myGLen){
// Eq(18) in [1]
if (u == myGLen - 1){
// beta(t, u) = beta(t+1, u) * null(t, u)
myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u]
+ myX[(t*myStrideT + u) * dictSize + blankIdx];
}
else if (t == myFLen - 1){
// beta(t, u) = beta(t, u+1) * y(t, u)
myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1]
+ myX[(t*myStrideT + u) * dictSize + myLabel[u]];
}
else{
// beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u)
acc_t current = myBeta[(t+1)*maxGLen + u]
+ myX[(t*myStrideT + u) * dictSize + blankIdx];
acc_t next = myBeta[t*maxGLen + u + 1]
+ myX[(t*myStrideT + u) * dictSize + myLabel[u]];
myBeta[t*maxGLen + u] = logSumExp(next, current);
}
}
}
__syncthreads();
}
if (tid == 0)
loss[batch] = -myBeta[0];
}
}
// Transducer loss function (i.e. forward-backward algorithm) with batch-loading optimization.
// Compared to the vanilla version, there are two optimizations:
// 1. Load x in batches through loop unrolling to reduce latency.
// 2. Use registers and shared memory to hold alpha and beta values passed from one step to the next.
// For simplicity, this kernel currently only supports U <= maxThread, which should be the common
// case. For cases where U > maxThread, the vanilla kernel is used as a fallback option.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// Forward (alpha) and backward (beta) paths are launched together. The input is assumed to have
// been converted into log scale by the preceding log_softmax layer.
// Diagonal wavefront advancing, as commonly used in dynamic programming, is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.
// This loss function supports packed input, where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// Don't-care regions (t >= audLen or u >= txtLen) are removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t, int batchLdSize>
__global__ void transducer_loss_batch_load_forward(
const scalar_t* x,
const int* label,
const int* audLen,
const int* txtLen,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
acc_t* alpha,
acc_t* beta,
scalar_t* loss) {
const int batch = blockIdx.y;
int u = threadIdx.x;
const auto myFLen = audLen[batch];
const auto myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
const scalar_t* myX = x + myBatchOffset * dictSize;
scalar_t next[batchLdSize], current[batchLdSize];
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
if (blockIdx.x == 0){
// alpha path
acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;
// two SMEM regions used to double-buffer reads and writes and avoid data races
acc_t * const sharedAlpha[2] = {smem, smem+maxGLen};
sharedAlpha[0][u] = 0;
__syncthreads();
if (u == 0)
myAlpha[0] = 0;
auto myAlphaLabel = (u == 0) ? 0 : label[batch*(maxGLen-1) + u - 1];
// register used to pass value to the next step for the same thread
acc_t prvStepAlpha = 0;
for (int64_t step = 1; step < myFLen+myGLen-1+batchLdSize; step += batchLdSize){
// Move along the diagonal wavefront to leverage available parallelism
// Batch loading X through loop unrolling
#pragma unroll
for (int i = 0; i < batchLdSize; ++i){
if (step+i<myFLen+myGLen-1){
// index computing
int64_t t = step + i - u;
int64_t currentId = ((t-1)*myStrideT + u) * dictSize + blankIdx;
int64_t nextId = (t*myStrideT + u - 1) * dictSize + myAlphaLabel;
// main loading loop
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
if (u == 0){
current[i] = myX[currentId];
}
else if (t == 0){
next[i] = myX[nextId];
}
else{
current[i] = myX[currentId];
next[i] = myX[nextId];
}
}
}
}
// main computing loop
for (int i = 0; i < batchLdSize; ++i){
// swap the pointer for double buffering
auto sharedAlphaRd = sharedAlpha[(step+i-1)%2];
auto sharedAlphaWr = sharedAlpha[(step+i)%2];
if (step+i<myFLen+myGLen-1){
int64_t t = step + i - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(16) in [1]
if (u == 0)
prvStepAlpha = prvStepAlpha+current[i];
else if (t == 0)
prvStepAlpha = sharedAlphaRd[u-1]+next[i];
else
prvStepAlpha = logSumExp(prvStepAlpha+current[i], sharedAlphaRd[u-1]
+ next[i]);
sharedAlphaWr[u] = prvStepAlpha;
myAlpha[t*maxGLen + u] = prvStepAlpha;
}
}
__syncthreads();
}
}
}
else if (blockIdx.x == 1){
// beta path
acc_t* myBeta = beta + batch*maxFLen*maxGLen;
// two SMEM regions used to double-buffer reads and writes and avoid data races
acc_t * const sharedBeta[2] = {smem, smem + maxGLen};
sharedBeta[0][u] = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
__syncthreads();
auto myBetaLabel = (u == maxGLen - 1) ? 0 : label[batch*(maxGLen-1) + u];
// register used to pass value to the next step for the same thread
acc_t prvStepBeta = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
if (u == 0)
myBeta[(myFLen-1)*maxGLen + myGLen - 1] = prvStepBeta;
for (int64_t step = 1; step < myFLen+myGLen-1; step += batchLdSize){
// Move along the diagonal wavefront to leverage available parallelism
// Batch loading X
#pragma unroll
for (int i = 0; i < batchLdSize; ++i){
if (step+i<myFLen+myGLen-1){
// index computing
int64_t t = myFLen+myGLen - (step + i) - 2 - u;
int64_t currentId = (t*myStrideT + u) * dictSize + blankIdx;
int64_t nextId = (t*myStrideT + u) * dictSize + myBetaLabel;
// main loading loop
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
if (u == myGLen - 1){
current[i] = myX[currentId];
}
else if (t == myFLen - 1){
next[i] = myX[nextId];
}
else{
current[i] = myX[currentId];
next[i] = myX[nextId];
}
}
}
}
// main computing loop
for (int i = 0; i < batchLdSize; ++i){
// swap the pointer for double buffering
auto sharedBetaRd = sharedBeta[(step+i-1)%2];
auto sharedBetaWr = sharedBeta[(step+i)%2];
if (step+i<myFLen+myGLen-1){
int64_t t = myFLen+myGLen - (step + i) - 2 - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(18) in [1]
if (u == myGLen - 1)
prvStepBeta = prvStepBeta+current[i];
else if (t == myFLen - 1)
prvStepBeta = sharedBetaRd[u+1]+next[i];
else
prvStepBeta = logSumExp(prvStepBeta+current[i], sharedBetaRd[u+1]
+ next[i]);
sharedBetaWr[u] = prvStepBeta;
myBeta[t*maxGLen + u] = prvStepBeta;
}
}
__syncthreads();
}
}
if (u == 0)
loss[batch] = -prvStepBeta;
}
}
// Vanilla transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere,
// hence only Eq(20) in [1] is implemented in this kernel.
// Each thread block works on [batch, t, :, :] of data. Each thread works on a specific u at a time
// Since only gradients for the correct token and null token need to be updated, gradients at other
// locations are initialized to 0.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int t = blockIdx.x;
const int batch = blockIdx.y;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
auto myX = x + (myBatchOffset + t*myStrideT)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT)*dictSize;
auto myLabel = label + batch*(maxGLen-1);
int64_t u = tid;
while (t < myFLen and u < myGLen){
// Do the update
// loss = -ln(Pr(y*|x))
acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
if (u != myGLen - 1)
myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1]
+ myX[u*dictSize + myLabel[u]]);
if (t == myFLen - 1 and u == myGLen - 1)
myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myX[u*dictSize + blankIdx]);
else if (t != myFLen - 1)
myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u]
+ myX[u*dictSize + blankIdx]);
u += blockDim.x;
}
}
// Fused transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_fused_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int u = blockIdx.x;
const int t = blockIdx.y;
const int batch = blockIdx.z;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
if (t < myFLen and u < myGLen){
auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myLabel = label + batch*(maxGLen-1);
// load and store shared variables in SMEM
if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u];
}
__syncthreads();
for (int64_t h = tid; h < dictSize; h += blockDim.x){
// Do the update
acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabel[u]){
myGrad -= std::exp(grad + myBeta[t*maxGLen + u + 1]);
}
else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad);
else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBeta[(t+1)*maxGLen + u]);
}
myXGrad[h] = myGrad;
}
}
else if (!packedInput){
// In non-pack mode, need to make sure the gradients for don't-care regions are zero.
for (int64_t h = tid; h < dictSize; h += blockDim.x){
myXGrad[h] = 0;
}
}
}
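For reference, the fusion above relies on the standard log_softmax backward: given the gradient g with respect to the log-probabilities, the gradient with respect to the logits is g - softmax * sum(g), and for this loss the sum over the dictionary collapses (via the beta recursion) into the -lossGrad * exp(alpha[t,u] + beta[t,u] - beta[0,0]) term that commonFactor + myBetaTU encodes. A dense PyTorch sketch of that identity (illustrative only):

import torch

def fuse_softmax_backward(x_log, grad_wrt_logprobs):
    # x_log = log_softmax(z); given dL/d x_log, return dL/dz along the last dim,
    # which is what the fused kernel writes directly into xGrad.
    return grad_wrt_logprobs - torch.exp(x_log) * grad_wrt_logprobs.sum(dim=-1, keepdim=True)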
std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput){
auto scalarType = x.scalar_type();
auto tensorOpt = x.options();
const int batchSize = label.size(0);
const int maxGLen = label.size(1) + 1;
const int dictSize = x.size(-1);
TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize,
"Expected blank index to be in the range of 0 to ",
dictSize-1,
", but got ",
blankIdx);
TORCH_CHECK(opt == -1 or opt == 0 or opt == 1,
"Got an invalid optimization level ",
opt);
// The data type of alpha and beta will be resolved at dispatch time,
// hence defined here and assigned later
torch::Tensor alpha;
torch::Tensor beta;
torch::Tensor loss = torch::empty({batchSize}, tensorOpt);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock;
const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(scalarType, "transducer_loss_cuda_forward", ([&] {
// resolve accumulation type
using acc_t = at::acc_type<scalar_t, true>;
auto accType = c10::CppTypeToScalarType<acc_t>::value;
auto accTensorOpt = tensorOpt.dtype(accType);
alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
// decide what kernel to launch based on the problem size
// if the required SMEM size or number of threads exceeds the limit, fall back to the vanilla
// kernel.
const auto smemSize = 2*maxGLen*sizeof(acc_t);
const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0
: (opt == -1) ? 1 : opt;
const int threads = std::min(maxThreadPerBlock, maxGLen);
const dim3 blocks(2, batchSize, 1);
if (optFallBack == 0)
transducer_loss_forward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
label.data_ptr<int>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
loss.data_ptr<scalar_t>());
else if (optFallBack == 1)
transducer_loss_batch_load_forward<scalar_t, acc_t, 4>
<<<blocks, threads, smemSize, stream>>>(
x.data_ptr<scalar_t>(),
label.data_ptr<int>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
loss.data_ptr<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
return {alpha, beta, loss};
}
torch::Tensor transducer_loss_cuda_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput){
auto dtype = x.scalar_type();
torch::Tensor xGrad;
const int batchSize = label.size(0);
const int maxGLen = label.size(1) + 1;
const int dictSize = x.size(-1);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
const int warpSize = deviceProperties->warpSize;
const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (fuseSoftmaxBackward){
// alloc empty tensors for performance, hence need to ensure zeros are written to the
// don't-care region in the kernel.
xGrad = torch::empty_like(x);
// Would like each thread to work on 4 hidden units
const int workPerThread = 4;
// Don't want to have more than 128 threads per thread block
const int maxThreadPerElmt = std::min(128, maxThreadPerBlock);
const int threads = std::min(maxThreadPerElmt, std::max(warpSize,
(dictSize+workPerThread-1)/workPerThread));
const dim3 blocks(maxGLen, maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
using acc_t = at::acc_type<scalar_t, true>;
transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}));
}
else{
// for the non-fused kernel, the gradients that need to be written are very sparse, hence initialize
// the tensor with all zeros.
xGrad = torch::zeros_like(x);
// don't launch more threads than needed.
const int threads = std::min(maxThreadPerBlock, maxGLen);
const dim3 blocks(maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
using acc_t = at::acc_type<scalar_t, true>;
transducer_loss_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}));
}
THCudaCheck(cudaGetLastError());
return xGrad;
}
import torch
import unittest
from apex.contrib.transducer import TransducerJoint
import transducer_ref
class TransducerJointTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, for_vector_kernel):
self.B = 4
T_min = 51
T_max = 101
U_min = 12
U_max = 25
if for_vector_kernel:
H = 512
else:
H = 509
dtype = torch.float16
device = "cuda"
self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)
self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)
self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
# Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
# the loss function
for b in range(self.B):
self.h_grad[b, self.f_len[b]:, :, :] = 0
self.h_grad[b, :, self.g_len[b]:, :] = 0
self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)
def _pack(self, x, f_len, g_len):
B = x.size(0)
list_x = []
for b in range(B):
list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(f_len * g_len, dim=0)
return x_packed
def run_transducer_joint(self, for_vector_kernel, pack_output):
self.gen_input(for_vector_kernel=for_vector_kernel)
# Generate reference
f_ref = self.f_tst.data.clone()
g_ref = self.g_tst.data.clone()
f_ref.requires_grad = True
g_ref.requires_grad = True
h_ref, f_grad_ref, g_grad_ref \
= transducer_ref.transducer_joint_reference(f=f_ref,
g=g_ref,
h_grad=self.h_grad,
f_len=self.f_len,
g_len=self.g_len,
pack_output=pack_output)
my_joint = TransducerJoint(pack_output=pack_output)
if not pack_output:
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len)
h_tst.backward(self.h_grad)
else:
batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len,
batch_offset=batch_offset,
packed_batch=batch_offset[-1])
h_tst.backward(self.h_grad_packed)
f_grad_tst = self.f_tst.grad
g_grad_tst = self.g_tst.grad
self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4))
def test_transducer_joint(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=False)
def test_transducer_joint_vec(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False)
def test_transducer_joint_pack(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True)
def test_transducer_joint_vec_pack(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True)
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.transducer import TransducerLoss
import transducer_ref
class TransducerLossTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, scalar_t):
self.B = 5
T_min = 23
T_max = 51
U_min = 12
U_max = 25
V = 16
self.blank_idx = V - 1
device = "cuda"
self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True,
device=device)
self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1
self.x_tst_packed, self.batch_offset = self._pack(self.x_tst)
# Generate reference
x_ref = self.x_tst.data.clone()
x_ref.requires_grad = True
loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0)
_, _, self.grad_ref, self.loss_ref \
= transducer_ref.transducer_loss_reference( x=x_ref,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx,
loss_grad=loss_grad)
def _pack(self, x):
list_x = []
for b in range(self.B):
list_x_row = [x[b, t, : self.y_len[b]+1] for t in range(self.f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(self.f_len * (self.y_len+1), dim=0)
return x_packed, batch_offset
def _unpack(self, x):
x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1),
dtype=x.dtype, device=x.device)
for b in range(self.B):
my_batch_offset = 0 if b == 0 else self.batch_offset[b-1]
my_f_len = self.f_len[b]
my_g_len = self.y_len[b] + 1
for t in range(my_f_len):
for u in range(my_g_len):
x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u]
return x_unpacked
def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input):
self.gen_input(scalar_t)
my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward,
packed_input=packed_input)
if not packed_input:
loss_tst = my_loss( x=self.x_tst,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx)
loss_tst.mean().backward()
grad_tst = self.x_tst.grad
else:
loss_tst = my_loss( x=self.x_tst_packed,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx,
batch_offset=self.batch_offset,
max_f_len=max(self.f_len))
loss_tst.mean().backward()
grad_tst_packed = self.x_tst_packed.grad
grad_tst = self._unpack(grad_tst_packed)
return loss_tst, grad_tst
def test_transducer_loss_fp32(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32,
fuse_softmax_backward=False,
packed_input=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5))
def test_transducer_loss_fp16(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=False,
packed_input=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
import torch
def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):
def log_sum_exp(a, b):
if (a >= b):
return a + torch.log(1 + torch.exp(b-a))
else:
return b + torch.log(1 + torch.exp(a-b))
def forward_alpha(x, label, f_len, y_len, blank_idx):
B, T, U, V = x.size()
acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
alpha = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
for b in range(B):
alpha[b, 0, 0] = 0
for t in range(1, f_len[b]):
alpha[b, t, 0] = alpha[b, t-1, 0] + x[b, t-1, 0, blank_idx]
for u in range(1, y_len[b]+1):
alpha[b, 0, u] = alpha[b, 0, u-1] + x[b, 0, u-1, label[b, u-1]]
for t in range(1, f_len[b]):
for u in range(1, y_len[b]+1):
curr_ = alpha[b, t-1, u] + x[b, t-1, u, blank_idx]
next_ = alpha[b, t, u-1] + x[b, t, u-1, label[b, u-1]]
alpha[b, t, u] = log_sum_exp(curr_, next_)
return alpha
def forward_beta(x, label, f_len, y_len, blank_idx):
B, T, U, V = x.shape
acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
beta = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
for b in range(B):
beta[b, f_len[b]-1, y_len[b]] = x[b, f_len[b]-1, y_len[b], blank_idx]
for t in range(f_len[b]-2, -1, -1):
beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx]
for u in range(y_len[b]-1, -1, -1):
beta[b, f_len[b]-1, u] = beta[b, f_len[b]-1, u+1] + x[b, f_len[b]-1, u, label[b, u]]
for t in range(f_len[b]-2, -1, -1):
for u in range(y_len[b]-1, -1, -1):
curr_ = beta[b, t+1, u] + x[b, t, u, blank_idx]
next_ = beta[b, t, u+1] + x[b, t, u, label[b, u]]
beta[b, t, u] = log_sum_exp(curr_, next_)
return beta
def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx):
grad = torch.zeros_like(x)
B, T, U, V = x.size()
for b in range(B):
common_factor = torch.log(loss_grad[b]) + alpha - beta[b, 0, 0]
# next
for u in range(y_len[b]):
grad[b, :f_len[b], u, label[b, u]] = -torch.exp(common_factor[b, :f_len[b], u]
+ beta[b, :f_len[b], u+1]
+ x[b, :f_len[b], u, label[b, u]])
# current
grad[b, :f_len[b]-1, :y_len[b]+1, blank_idx] \
= -torch.exp(common_factor[b, :f_len[b]-1, :y_len[b]+1]
+ beta[b, 1:f_len[b], :y_len[b]+1]
+ x[b, :f_len[b]-1, :y_len[b]+1, blank_idx])
grad[b, f_len[b]-1, y_len[b], blank_idx] = -torch.exp(common_factor[b, f_len[b]-1, y_len[b]]
+ x[b, f_len[b]-1, y_len[b], blank_idx])
return grad
x_log = torch.nn.functional.log_softmax(x, dim=-1)
alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx)
beta = forward_beta(x_log, label, f_len, y_len, blank_idx)
grad = backward(x_log, label, f_len, y_len, alpha, beta,
loss_grad, blank_idx)
x_log.backward(grad)
loss = -beta[:, 0, 0]
loss = loss.to(x.dtype)
return alpha, beta, x.grad, loss
def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output):
B, T, H = f.size()
U = g.size(1)
f_expand = f.unsqueeze(dim=2)
g_expand = g.unsqueeze(dim=1)
h = f_expand + g_expand
h.backward(h_grad)
if pack_output == False:
# intentionally set the don't-care region to -1 to test whether the transducer joint
# writes these regions to avoid NaN and inf
for b in range(B):
h[b, f_len[b]:] = -1
h[b, :, g_len[b]:] = -1
return h, f.grad, g.grad
# packing
list_to_pack = []
for b in range(B):
list_to_pack.append(h[b, :f_len[b], :g_len[b], :].reshape(-1, H))
h_packed = torch.cat(list_to_pack)
return h_packed, f.grad, g.grad
from .transducer import TransducerJoint
from .transducer import TransducerLoss
import torch
import transducer_loss_cuda
import transducer_joint_cuda
class TransducerJoint(torch.nn.Module):
"""Transducer joint
    Details of this operation can be found in: Sequence Transduction with Recurrent Neural
Networks
Arguments:
pack_output (bool, optional): whether to pack the output in a compact form with don't-care
data being removed. (default: False)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm.
(default: 1)
fwd_tile_size (int, optional): tile size used in forward operation. This argument will be
ignored if opt != 1. (default: 4)
"""
def __init__(self, pack_output=False, opt=1, fwd_tile_size=4):
super(TransducerJoint, self).__init__()
self.pack_output = pack_output
self.opt = opt
self.fwd_tile_size = fwd_tile_size
self.dummy_batch_offset = torch.empty(0)
def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):
"""Forward operation of transducer joint
Arguments:
f (tensor): transcription vector from encode block of shape (B, T, H).
            g (tensor): prediction vector from the predict block of shape (B, U, H).
f_len (tensor): length of transcription vector for each batch.
g_len (tensor): length of prediction vector minus 1 for each batch.
batch_offset (tensor, optional): tensor containing the offset of each batch
in the results. For example, batch offset can be obtained from:
batch_offset = torch.cumsum(f_len*g_len, dim=0)
This argument is required if pack_output == True, and is ignored if
pack_output == False. (default: None)
packed_batch (int, optional): the batch size after packing. This argument is
ignored if pack_output == False. (default: 0)
"""
my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset
if self.pack_output and (batch_offset is None or packed_batch == 0):
raise Exception("Please specify batch_offset and packed_batch when packing is enabled")
return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, my_batch_offset,
packed_batch, self.opt, self.fwd_tile_size)
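
# Minimal usage sketch (an assumption, not part of the original file). It requires the compiled
# transducer_joint_cuda extension and a CUDA device; the function name, shapes and the int32
# length dtype are illustrative assumptions only.
def _example_transducer_joint():
    B, T, U, H = 2, 4, 3, 64
    joint = TransducerJoint(pack_output=False)
    f = torch.randn(B, T, H, device='cuda', requires_grad=True)   # encoder output
    g = torch.randn(B, U, H, device='cuda', requires_grad=True)   # predictor output
    f_len = torch.tensor([4, 3], dtype=torch.int32, device='cuda')
    g_len = torch.tensor([3, 2], dtype=torch.int32, device='cuda')
    h = joint(f, g, f_len, g_len)                                  # shape (B, T, U, H)
    h.sum().backward()                                             # populates f.grad and g.grad
    return h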
class TransducerLoss(torch.nn.Module):
"""Transducer loss
    Details of this loss function can be found in: Sequence Transduction with Recurrent Neural
Networks
Arguments:
        fuse_softmax_backward (bool, optional): whether to fuse the backward of the transducer loss with
softmax. (default: True)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized
algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1)
        packed_input (bool, optional): whether the input is packed in a compact form with don't-care
            data removed. (default: False)
"""
def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False):
super(TransducerLoss, self).__init__()
self.fuse_softmax_backward = fuse_softmax_backward
self.opt = opt
self.packed_input = packed_input
self.dummy_batch_offset = torch.empty(0)
def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None,
debug_list=None):
"""Forward operation of transducer joint
Arguments:
x (tensor): input tensor to the loss function with a shape of (B, T, U, H).
label (tensor): labels for the input data.
f_len (tensor): lengths of the inputs in the time dimension for each batch.
y_len (tensor): lengths of the labels for each batch.
blank_idx (int): index for the null symbol.
batch_offset (tensor, optional): tensor containing the offset of each batch
in the input. For example, batch offset can be obtained from:
batch_offset = torch.cumsum(f_len*(y_len+1), dim=0)
This argument is required if packed_input == True, and is ignored if
packed_input == False. (default: None)
max_f_len (int, optional): maximum length of the input in the time dimension.
For example, it can be obtained as
max_f_len = max(f_len)
This argument is required if packed_input == True, and is ignored if
packed_input == False. (default: None)
debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated
                in the forward operation will be attached to this list for debugging purposes.
(default: None)
"""
if self.packed_input:
if batch_offset is None or max_f_len is None:
raise Exception("Please specify batch_offset and max_f_len when packing is \
enabled")
my_batch_offset = batch_offset
my_max_f_len = max_f_len
else:
my_batch_offset = self.dummy_batch_offset
my_max_f_len = x.size(1)
return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len,
blank_idx, self.fuse_softmax_backward, debug_list,
self.opt, self.packed_input)
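
# Minimal usage sketch (an assumption, not part of the original file). It requires the compiled
# transducer_loss_cuda extension and a CUDA device; the function name, shapes, dtypes and
# blank_idx are illustrative assumptions only.
def _example_transducer_loss():
    B, T, U, V = 2, 4, 3, 16                      # U = max label length + 1
    blank_idx = 0
    loss_fn = TransducerLoss(packed_input=False)
    x = torch.randn(B, T, U, V, device='cuda', requires_grad=True)
    label = torch.randint(1, V, (B, U - 1), device='cuda', dtype=torch.int32)
    f_len = torch.tensor([4, 3], dtype=torch.int32, device='cuda')
    y_len = torch.tensor([2, 1], dtype=torch.int32, device='cuda')
    loss = loss_fn(x, label, f_len, y_len, blank_idx)
    loss.sum().backward()                         # populates x.grad
    return loss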
class TransducerLossFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx,
fuse_softmax_backward, debug_list, opt, packed_input):
        if not fuse_softmax_backward:
with torch.enable_grad():
x = torch.nn.functional.log_softmax(x, dim=-1)
else:
x = torch.nn.functional.log_softmax(x, dim=-1)
alpha, beta, loss = transducer_loss_cuda.forward( x, label, f_len, y_len, batch_offset,
max_f_len, blank_idx, opt, packed_input)
if debug_list == []:
debug_list += [alpha, beta]
ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset)
ctx.blank_idx = blank_idx
ctx.fuse_softmax_backward = fuse_softmax_backward
ctx.opt = opt
ctx.packed_input = packed_input
ctx.max_f_len = max_f_len
return loss
@staticmethod
def backward(ctx, loss_grad):
x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors
x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label,
batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt,
ctx.fuse_softmax_backward, ctx.packed_input)
        if not ctx.fuse_softmax_backward:
x_grad = x.backward(x_grad)
return x_grad, None, None, None, None, None, None, None, None, None, None
class TransducerJointFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, f, g, f_len, g_len, pack_output, batch_offset, packed_batch, opt,
fwd_tile_size):
h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt,
pack_output, fwd_tile_size)
ctx.save_for_backward(f_len, g_len, batch_offset)
ctx.pack_output = pack_output
ctx.max_f_len = f.size(1)
ctx.max_g_len = g.size(1)
return h
@staticmethod
def backward(ctx, loss_grad):
f_len, g_len, batch_offset = ctx.saved_tensors
f_grad, g_grad = transducer_joint_cuda.backward(loss_grad, f_len, g_len, batch_offset,
ctx.max_f_len, ctx.max_g_len,
ctx.pack_output)
return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None
@@ -453,6 +453,31 @@ if "--transducer" in sys.argv:
                                       '--expt-extended-lambda',
                                       '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
if "--transducer" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--transducer")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--transducer was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
CUDAExtension(name='transducer_joint_cuda',
sources=['apex/contrib/csrc/transducer/transducer_joint.cpp',
'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='transducer_loss_cuda',
sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
'apex/contrib/csrc/transducer/transducer_loss_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
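# The transducer extensions are opt-in; as with the other optional extensions gated on sys.argv
# flags in this file, they can be built with, for example (command shown as an assumption):
#   python setup.py install --transducer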
setup(
    name='apex',
    version='0.1',
...