Unverified Commit 0c2c6eea authored by Nan Zheng, committed by GitHub

Added more fusion and vectorized kernel for transducer (#1125)

* Added support for fused ReLU and dropout into transducer joint

* Reorganized code selection path in transducer joint fwd
* Added support for fused ReLU+dropout into transducer joint

* Vectorize transducer loss backward with fused softmax (#3)

* Nanz/transducer loss (#4)

* Vectorize transducer loss backward with fused softmax

* Added a predicate to avoid potential IMA

* Nanz/transducer loss (#5)

* Vectorize transducer loss backward with fused softmax

* Added a predicate to avoid potential IMA

* Added more predicates to avoid IMAs

* Updated documentation for the newly added features.

* Fixed an error in transducer.py
parent ed719967
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor transducer_joint_cuda_forward( std::vector<torch::Tensor> transducer_joint_cuda_forward(
torch::Tensor f, torch::Tensor f,
torch::Tensor g, torch::Tensor g,
torch::Tensor fLen, torch::Tensor fLen,
...@@ -14,19 +14,23 @@ torch::Tensor transducer_joint_cuda_forward( ...@@ -14,19 +14,23 @@ torch::Tensor transducer_joint_cuda_forward(
int64_t packedBatch, int64_t packedBatch,
int opt, int opt,
bool packOutput, bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize); int tileSize);
std::vector<torch::Tensor> transducer_joint_cuda_backward( std::vector<torch::Tensor> transducer_joint_cuda_backward(
torch::Tensor grad, std::vector<torch::Tensor> in,
torch::Tensor fLen, torch::Tensor fLen,
torch::Tensor gLen, torch::Tensor gLen,
torch::Tensor batchOffset, torch::Tensor batchOffset,
int maxFLen, int maxFLen,
int maxGLen, int maxGLen,
bool packOutput); bool packOutput,
float scale);
torch::Tensor transducer_joint_forward( std::vector<torch::Tensor> transducer_joint_forward(
torch::Tensor f, torch::Tensor f,
torch::Tensor g, torch::Tensor g,
torch::Tensor fLen, torch::Tensor fLen,
...@@ -35,6 +39,9 @@ torch::Tensor transducer_joint_forward( ...@@ -35,6 +39,9 @@ torch::Tensor transducer_joint_forward(
int64_t packedBatch, int64_t packedBatch,
int opt, int opt,
bool packOutput, bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize) { int tileSize) {
CHECK_INPUT(f); CHECK_INPUT(f);
CHECK_INPUT(g); CHECK_INPUT(g);
...@@ -51,30 +58,37 @@ torch::Tensor transducer_joint_forward( ...@@ -51,30 +58,37 @@ torch::Tensor transducer_joint_forward(
packedBatch, packedBatch,
opt, opt,
packOutput, packOutput,
relu,
dropout,
dropoutProb,
tileSize); tileSize);
} }
std::vector<torch::Tensor> transducer_joint_backward( std::vector<torch::Tensor> transducer_joint_backward(
torch::Tensor grad, std::vector<torch::Tensor> in,
torch::Tensor fLen, torch::Tensor fLen,
torch::Tensor gLen, torch::Tensor gLen,
torch::Tensor batchOffset, torch::Tensor batchOffset,
int maxFLen, int maxFLen,
int maxGLen, int maxGLen,
bool packOutput) { bool packOutput,
CHECK_INPUT(grad); float scale) {
for (auto t : in){
CHECK_INPUT(t);
}
CHECK_INPUT(fLen); CHECK_INPUT(fLen);
CHECK_INPUT(gLen); CHECK_INPUT(gLen);
if (packOutput) if (packOutput)
CHECK_INPUT(batchOffset); CHECK_INPUT(batchOffset);
return transducer_joint_cuda_backward( return transducer_joint_cuda_backward(
grad, in,
fLen, fLen,
gLen, gLen,
batchOffset, batchOffset,
maxFLen, maxFLen,
maxGLen, maxGLen,
packOutput); packOutput,
scale);
} }
......
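For orientation, here is a minimal sketch of the updated extension interface declared above, assuming the transducer_joint_cuda module from this commit is built and a CUDA device is available (the real integration lives in transducer.py further down in this diff):

# Sketch only: forward now returns [sum] or [sum, mask]; backward takes [grad] or [grad, mask] plus a rescale factor.
import torch
import transducer_joint_cuda

B, T, U, H = 2, 4, 3, 64
f = torch.randn(B, T, H, device="cuda", dtype=torch.float16)
g = torch.randn(B, U, H, device="cuda", dtype=torch.float16)
f_len = torch.full((B,), T, dtype=torch.int, device="cuda")
g_len = torch.full((B,), U, dtype=torch.int, device="cuda")
dummy_batch_offset = torch.empty(0)     # only inspected when packing is enabled
relu, dropout, dropout_prob = True, True, 0.5

# args: f, g, f_len, g_len, batch_offset, packed_batch, opt, pack_output, relu, dropout, dropout_prob, tile_size
out = transducer_joint_cuda.forward(f, g, f_len, g_len, dummy_batch_offset, 0, 1,
                                    False, relu, dropout, dropout_prob, 4)
h = out[0]
mask = out[1] if (relu or dropout) else None    # uint8 mask, only returned when masked

h_grad = torch.randn_like(h)
scale = 1.0 / (1.0 - dropout_prob) if dropout and dropout_prob != 1 else 1.0
inp = [h_grad, mask] if mask is not None else [h_grad]
f_grad, g_grad = transducer_joint_cuda.backward(inp, f_len, g_len, dummy_batch_offset,
                                                T, U, False, scale)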
...@@ -5,6 +5,10 @@ ...@@ -5,6 +5,10 @@
#include <THC/THC.h> #include <THC/THC.h>
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h>
#include "philox.h"
// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width. // Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
// width should be a power of 2 and should be less than warpSize. // width should be a power of 2 and should be less than warpSize.
...@@ -23,6 +27,21 @@ inline int largestPowerOfTwo(int x){ ...@@ -23,6 +27,21 @@ inline int largestPowerOfTwo(int x){
return y >> 1; return y >> 1;
} }
/*
Figure out vectorization type for masks.
Similar to how PyTorch figures out acc_t here:
aten/src/ATen/AccumulateType.h
*/
template <int V>
struct MaskVecType { };
template <> struct MaskVecType<1> { using type = uint8_t; };
template <> struct MaskVecType<2> { using type = uint16_t; };
template <> struct MaskVecType<4> { using type = uint32_t; };
template<int V>
using mvec_type = typename MaskVecType<V>::type;
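The trait above simply picks an unsigned integer type whose byte width matches the vector width V, so that V uint8 mask values can be moved with a single load/store alongside the vectorized gradient accesses. A quick illustration in NumPy (not part of this commit):

import numpy as np

mask = np.array([1, 0, 1, 1], dtype=np.uint8)        # four per-element keep flags
packed = mask.view(np.uint32)                        # reinterpreted as one 32-bit word
assert packed.shape == (1,)                          # V=4 lanes moved in a single load
assert np.array_equal(packed.view(np.uint8), mask)   # round-trips back to the same bytes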
// Helper class to calculate pointer offset that can be shared by different flavors of kernels. // Helper class to calculate pointer offset that can be shared by different flavors of kernels.
// For fwd, batch offset and stride are different for packing and non-packing mode. // For fwd, batch offset and stride are different for packing and non-packing mode.
struct OffsetCalFwd{ struct OffsetCalFwd{
...@@ -192,23 +211,31 @@ __global__ void transducer_joint_forward( ...@@ -192,23 +211,31 @@ __global__ void transducer_joint_forward(
} }
} }
/*
Tiled version of the joint forward kernel.
Detail of this joint function can be found in:
[1] Sequence Transduction with Recurrent Neural Networks.

f is a tensor of shape [batch, T, H]
g is a tensor of shape [batch, U, H]
the transducer joint does
sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
The resultant tensor is of shape [batch, T, U, H].
Each thread works on a tile of shape tileF x tileG in the result tensor.
The input for the tile is first loaded into registers and is reused tileG and tileF times.

This joint function can optionally pack the output, where an output tensor of shape
[B, T, U, H] is packed into [B_packed, H].
The don't-care region (t > fLen) or (u > gLen) is removed.
To enable packing, the starting offset for each batch needs to be specified with batchOffset.

Optionally this joint function performs ReLU and/or dropout on the joint output, controlled by the
arguments relu and dropout, respectively. philoxArgs is the argument used for generating
pseudorandom numbers. When at least one of ReLU and dropout is enabled, the joint function becomes
a masked operation, controlled by the template argument masked. In this case, the masks are saved
for the backward pass.
*/
template <typename scalar_t, int tileF, int tileG, int U, class OffsetCal, bool masked>
__global__ void transducer_joint_tiled_forward( __global__ void transducer_joint_tiled_forward(
const scalar_t *f, const scalar_t *f,
const scalar_t *g, const scalar_t *g,
...@@ -220,8 +247,14 @@ __global__ void transducer_joint_tiled_forward( ...@@ -220,8 +247,14 @@ __global__ void transducer_joint_tiled_forward(
int64_t hiddenSize, int64_t hiddenSize,
int64_t hiddenPerBlock, int64_t hiddenPerBlock,
bool packOutput, bool packOutput,
scalar_t *sum) { bool relu,
bool dropout,
float p,
at::PhiloxCudaState philoxArgs,
scalar_t *sum,
uint8_t *mask) {
static_assert(U == 4, "U has to be 4, as random numbers are generated in batch of 4");
const int batch = blockIdx.z; const int batch = blockIdx.z;
const int t = blockIdx.y * tileF; const int t = blockIdx.y * tileF;
...@@ -239,6 +272,17 @@ __global__ void transducer_joint_tiled_forward( ...@@ -239,6 +272,17 @@ __global__ void transducer_joint_tiled_forward(
scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset; scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset;
scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset; scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset;
scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset; scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset;
uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset;
// The following code is only needed for dropout; we try to bypass it as much as possible.
auto seeds = masked ? at::cuda::philox::unpack(philoxArgs)
: std::make_tuple(static_cast<uint64_t>(0), static_cast<uint64_t>(0));
uint64_t tid = masked ? (static_cast<uint64_t>(blockIdx.z)*gridDim.y*gridDim.x +
blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x
: 0;
Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
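// Note: p is the keep probability (the launcher passes 1.0f - dropoutProb), so a lane is
// kept when its random draw is < p, and scale = 1/p implements inverted dropout.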
scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0;
bool dropoutMask[U];
if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){ if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){
// register buffers for tiled input reuse // register buffers for tiled input reuse
...@@ -256,8 +300,28 @@ __global__ void transducer_joint_tiled_forward( ...@@ -256,8 +300,28 @@ __global__ void transducer_joint_tiled_forward(
if (t + i < myFLen){ if (t + i < myFLen){
#pragma unroll #pragma unroll
for (int j = 0; j < tileG; ++j){ for (int j = 0; j < tileG; ++j){
if (u + j < myGLen) int idx = i*tileG + j;
mySum[i*strideF + j*hiddenSize + h] = fBuffer[i] + gBuffer[j]; if (masked and dropout and idx % U == 0){
// For performance, generate 4 random numbers in one shot
// auto rand4 = curand_uniform4(&state);
auto rand4 = uniform4(ph());
dropoutMask[0] = rand4.x < p;
dropoutMask[1] = rand4.y < p;
dropoutMask[2] = rand4.z < p;
dropoutMask[3] = rand4.w < p;
}
if (u + j < myGLen){
scalar_t out = fBuffer[i] + gBuffer[j];
if (masked){
// Apply ReLU here when relu is True
bool localMask = relu ? (out>0) : 1;
localMask = dropout ? localMask & dropoutMask[idx%U] : localMask;
out = dropout ? out*localMask*scale : out*localMask;
myMask[i*strideF + j*hiddenSize + h] = static_cast<uint8_t>(localMask);
}
mySum[i*strideF + j*hiddenSize + h] = out;
}
else if (packOutput == false and u + j < maxGLen) else if (packOutput == false and u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1; mySum[i*strideF + j*hiddenSize + h] = -1;
} }
...@@ -287,15 +351,21 @@ __global__ void transducer_joint_tiled_forward( ...@@ -287,15 +351,21 @@ __global__ void transducer_joint_tiled_forward(
} }
} }
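To make the semantics of the masked tiled kernel above concrete, here is a small PyTorch sketch of the same computation; it mirrors the reference implementation used by the unit tests later in this commit. The function name and keep_prob argument are illustrative, with keep_prob playing the role of the kernel's p:

import torch

def joint_relu_dropout_reference(f, g, relu=True, dropout=True, keep_prob=0.5):
    h = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)            # [B, T, U, H]
    mask = torch.ones_like(h, dtype=torch.bool)
    if relu:
        mask &= h > 0                                      # ReLU recorded as part of the mask
    if dropout:
        mask &= torch.rand_like(h, dtype=torch.float) < keep_prob
    scale = (1.0 / keep_prob) if dropout else 1.0          # scaling only applies with dropout
    return h * mask * scale, mask.to(torch.uint8)          # output and the saved uint8 mask

f = torch.randn(2, 3, 8)
g = torch.randn(2, 5, 8)
out, mask = joint_relu_dropout_reference(f, g)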
/*
Bwd operation (reduction) on one input tensor. Since the operation performed for the two input
tensors is exactly the same, only one kernel is needed, and the different indexing offsets
and strides are handled by OffsetCalBwd.

When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.

When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, class OffsetCal, bool masked>
__device__ void transducer_joint_single_backward( __device__ void transducer_joint_single_backward(
const scalar_t *grad, const scalar_t *grad,
const uint8_t *mask,
const int *fLen, const int *fLen,
const int *gLen, const int *gLen,
const int64_t *batchOffset, const int64_t *batchOffset,
...@@ -304,6 +374,7 @@ __device__ void transducer_joint_single_backward( ...@@ -304,6 +374,7 @@ __device__ void transducer_joint_single_backward(
int64_t hiddenSize, int64_t hiddenSize,
bool packOutput, bool packOutput,
bool bwdFasterDim, // whether bwd on the faster moving dimension (u) bool bwdFasterDim, // whether bwd on the faster moving dimension (u)
float scale,
scalar_t *inGrad, scalar_t *inGrad,
int yBlockOffset=0) { int yBlockOffset=0) {
...@@ -331,15 +402,20 @@ __device__ void transducer_joint_single_backward( ...@@ -331,15 +402,20 @@ __device__ void transducer_joint_single_backward(
const auto myBatchOffset = offsetCal.getBatchOffset(); const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideX = offsetCal.getStrideX(); const auto strideX = offsetCal.getStrideX();
const auto strideY = offsetCal.getStrideY(); const auto strideY = offsetCal.getStrideY();
scalar_t const *myGrad = grad + myBatchOffset + x*strideX + hOffset; const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr;
// Each warp reduces numYPerWarp "y" first // Each warp reduces numYPerWarp "y" first
acc_t warpSum = 0; acc_t warpSum = 0;
auto numYPerWarp = (myYLen+numWarp-1)/numWarp; auto numYPerWarp = (myYLen+numWarp-1)/numWarp;
#pragma unroll
for (int warpY = 0; warpY < numYPerWarp; ++warpY){ for (int warpY = 0; warpY < numYPerWarp; ++warpY){
auto y = wid*numYPerWarp + warpY; auto y = wid*numYPerWarp + warpY;
if (y < myYLen and (hOffset+lid) < hiddenSize) if (y < myYLen and (hOffset+lid) < hiddenSize)
warpSum += myGrad[y*strideY + lid]; if (masked)
warpSum += static_cast<acc_t>(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale;
else
warpSum += myGrad[y*strideY + lid];
} }
// transpose partial sum in SMEM and reduce further using warpReduce // transpose partial sum in SMEM and reduce further using warpReduce
...@@ -366,13 +442,18 @@ __device__ void transducer_joint_single_backward( ...@@ -366,13 +442,18 @@ __device__ void transducer_joint_single_backward(
} }
} }
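The reduction performed here (and repeated with wider loads by the vectorized variant further below) corresponds to the following PyTorch sketch for the non-packed [B, T, U, H] case, ignoring the fLen/gLen bounds handling for brevity; names are illustrative:

import torch

def joint_backward_reference(h_grad, mask=None, scale=1.0):
    # h_grad: [B, T, U, H]; mask: uint8 mask saved by the fwd pass, or None when unmasked
    g = h_grad if mask is None else h_grad.float() * mask * scale
    f_grad = g.sum(dim=2)     # reduce over u -> [B, T, H]
    g_grad = g.sum(dim=1)     # reduce over t -> [B, U, H]
    return f_grad, g_grad

h_grad = torch.randn(2, 3, 5, 8)
mask = (torch.rand(2, 3, 5, 8) < 0.5).to(torch.uint8)
f_grad, g_grad = joint_backward_reference(h_grad, mask, scale=2.0)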
/*
The actual bwd (reduction) kernel that gets launched.
Calls transducer_joint_single_backward twice, once for each input tensor.
The two bwd ops are launched together; the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.

When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, class OffsetCal, bool masked>
__global__ void transducer_joint_combined_backward( __global__ void transducer_joint_combined_backward(
const scalar_t *grad, const scalar_t *grad,
const uint8_t *mask,
const int *fLen, const int *fLen,
const int *gLen, const int *gLen,
const int64_t *batchOffset, const int64_t *batchOffset,
...@@ -380,11 +461,13 @@ __global__ void transducer_joint_combined_backward( ...@@ -380,11 +461,13 @@ __global__ void transducer_joint_combined_backward(
int64_t maxGLen, int64_t maxGLen,
int64_t hiddenSize, int64_t hiddenSize,
bool packOutput, bool packOutput,
float scale,
scalar_t *fGrad, scalar_t *fGrad,
scalar_t *gGrad) { scalar_t *gGrad) {
if (blockIdx.y < maxFLen){ if (blockIdx.y < maxFLen){
transducer_joint_single_backward<scalar_t, acc_t, OffsetCal>( transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(
grad, grad,
mask,
fLen, fLen,
gLen, gLen,
batchOffset, batchOffset,
...@@ -393,11 +476,13 @@ __global__ void transducer_joint_combined_backward( ...@@ -393,11 +476,13 @@ __global__ void transducer_joint_combined_backward(
hiddenSize, hiddenSize,
packOutput, packOutput,
false, false,
scale,
fGrad); fGrad);
} }
else{ else{
transducer_joint_single_backward<scalar_t, acc_t, OffsetCal>( transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(
grad, grad,
mask,
fLen, fLen,
gLen, gLen,
batchOffset, batchOffset,
...@@ -406,19 +491,25 @@ __global__ void transducer_joint_combined_backward( ...@@ -406,19 +491,25 @@ __global__ void transducer_joint_combined_backward(
hiddenSize, hiddenSize,
packOutput, packOutput,
true, true,
scale,
gGrad, gGrad,
maxFLen); maxFLen);
} }
} }
/*
Vectorized version of transducer_joint_single_backward.
Performs the exact same operation as transducer_joint_single_backward, except that the loads and
stores are vectorized.

When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.

When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>
__device__ void transducer_joint_single_vec_backward( __device__ void transducer_joint_single_vec_backward(
const scalar_t *grad, const scalar_t *grad,
const uint8_t *mask,
const int *fLen, const int *fLen,
const int *gLen, const int *gLen,
const int64_t *batchOffset, const int64_t *batchOffset,
...@@ -427,6 +518,7 @@ __device__ void transducer_joint_single_vec_backward( ...@@ -427,6 +518,7 @@ __device__ void transducer_joint_single_vec_backward(
int64_t hiddenSize, int64_t hiddenSize,
bool packOutput, bool packOutput,
bool bwdFasterDim, bool bwdFasterDim,
float scale,
scalar_t *inGrad, scalar_t *inGrad,
int yBlockOffset=0){ int yBlockOffset=0){
...@@ -437,6 +529,9 @@ __device__ void transducer_joint_single_vec_backward( ...@@ -437,6 +529,9 @@ __device__ void transducer_joint_single_vec_backward(
const int lid = threadIdx.x; const int lid = threadIdx.x;
const int numWarp = blockDim.y; const int numWarp = blockDim.y;
// Figure out the vectorization type for mask
using mvec_t = mvec_type<V>;
OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput,
bwdFasterDim); bwdFasterDim);
const auto maxXLen = offsetCal.getMaxXLen(); const auto maxXLen = offsetCal.getMaxXLen();
...@@ -448,6 +543,7 @@ __device__ void transducer_joint_single_vec_backward( ...@@ -448,6 +543,7 @@ __device__ void transducer_joint_single_vec_backward(
acc_t warpSum[V]; acc_t warpSum[V];
scalar_t inBuffer[V]; scalar_t inBuffer[V];
uint8_t maskBuffer[V];
scalar_t outBuffer[V]; scalar_t outBuffer[V];
auto myInGradVec = reinterpret_cast<vec_t*>(myInGrad); auto myInGradVec = reinterpret_cast<vec_t*>(myInGrad);
auto outBufferVec = reinterpret_cast<vec_t*>(outBuffer); auto outBufferVec = reinterpret_cast<vec_t*>(outBuffer);
...@@ -457,6 +553,8 @@ __device__ void transducer_joint_single_vec_backward( ...@@ -457,6 +553,8 @@ __device__ void transducer_joint_single_vec_backward(
const auto strideX = offsetCal.getStrideX(); const auto strideX = offsetCal.getStrideX();
const auto strideY = offsetCal.getStrideY(); const auto strideY = offsetCal.getStrideY();
const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset; const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset
:nullptr;
for (int i = 0; i < V; ++i) for (int i = 0; i < V; ++i)
warpSum[i] = 0; warpSum[i] = 0;
...@@ -466,12 +564,22 @@ __device__ void transducer_joint_single_vec_backward( ...@@ -466,12 +564,22 @@ __device__ void transducer_joint_single_vec_backward(
for (int warpY = 0; warpY < numYPerWarp; ++warpY){ for (int warpY = 0; warpY < numYPerWarp; ++warpY){
auto y = wid*numYPerWarp + warpY; auto y = wid*numYPerWarp + warpY;
auto myGradVec = reinterpret_cast<vec_t const *>(myGrad + y*strideY); auto myGradVec = reinterpret_cast<vec_t const *>(myGrad + y*strideY);
auto myMaskVec = masked ? reinterpret_cast<mvec_t const *>(myMask + y*strideY)
: nullptr;
auto inBufferVec = reinterpret_cast<vec_t*>(inBuffer); auto inBufferVec = reinterpret_cast<vec_t*>(inBuffer);
auto maskBufferVec = reinterpret_cast<mvec_t*>(maskBuffer);
if (hOffset + lid*V < hiddenSize and y < myYLen){ if (hOffset + lid*V < hiddenSize and y < myYLen){
*inBufferVec = myGradVec[lid]; // vectorized load *inBufferVec = myGradVec[lid]; // vectorized load
#pragma unroll if (masked){
for (int i = 0; i < V; ++i){ *maskBufferVec = myMaskVec[lid];
warpSum[i] += inBuffer[i]; #pragma unroll
for (int i = 0; i < V; ++i)
warpSum[i] += static_cast<acc_t>(inBuffer[i]) * maskBuffer[i] * scale;
}
else{
#pragma unroll
for (int i = 0; i < V; ++i)
warpSum[i] += inBuffer[i];
} }
} }
} }
...@@ -506,13 +614,18 @@ __device__ void transducer_joint_single_vec_backward( ...@@ -506,13 +614,18 @@ __device__ void transducer_joint_single_vec_backward(
} }
} }
/*
Vectorized version of transducer_joint_combined_backward.
Calls transducer_joint_single_vec_backward twice, once for each input tensor.
The two bwd ops are launched together; the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.

When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>
__global__ void transducer_joint_combined_vec_backward( __global__ void transducer_joint_combined_vec_backward(
const scalar_t *grad, const scalar_t *grad,
const uint8_t *mask,
const int *fLen, const int *fLen,
const int *gLen, const int *gLen,
const int64_t *batchOffset, const int64_t *batchOffset,
...@@ -520,11 +633,13 @@ __global__ void transducer_joint_combined_vec_backward( ...@@ -520,11 +633,13 @@ __global__ void transducer_joint_combined_vec_backward(
int64_t maxGLen, int64_t maxGLen,
int64_t hiddenSize, int64_t hiddenSize,
bool packOutput, bool packOutput,
float scale,
scalar_t *fGrad, scalar_t *fGrad,
scalar_t *gGrad) { scalar_t *gGrad) {
if (blockIdx.y < maxFLen){ if (blockIdx.y < maxFLen){
transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal>( transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(
grad, grad,
mask,
fLen, fLen,
gLen, gLen,
batchOffset, batchOffset,
...@@ -533,11 +648,13 @@ __global__ void transducer_joint_combined_vec_backward( ...@@ -533,11 +648,13 @@ __global__ void transducer_joint_combined_vec_backward(
hiddenSize, hiddenSize,
packOutput, packOutput,
false, false,
scale,
fGrad); fGrad);
} }
else{ else{
transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal>( transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(
grad, grad,
mask,
fLen, fLen,
gLen, gLen,
batchOffset, batchOffset,
...@@ -546,6 +663,7 @@ __global__ void transducer_joint_combined_vec_backward( ...@@ -546,6 +663,7 @@ __global__ void transducer_joint_combined_vec_backward(
hiddenSize, hiddenSize,
packOutput, packOutput,
true, true,
scale,
gGrad, gGrad,
maxFLen); maxFLen);
} }
...@@ -554,7 +672,7 @@ __global__ void transducer_joint_combined_vec_backward( ...@@ -554,7 +672,7 @@ __global__ void transducer_joint_combined_vec_backward(
torch::Tensor transducer_joint_cuda_forward( std::vector<torch::Tensor> transducer_joint_cuda_forward(
torch::Tensor f, torch::Tensor f,
torch::Tensor g, torch::Tensor g,
torch::Tensor fLen, torch::Tensor fLen,
...@@ -563,6 +681,9 @@ torch::Tensor transducer_joint_cuda_forward( ...@@ -563,6 +681,9 @@ torch::Tensor transducer_joint_cuda_forward(
int64_t packedBatch, int64_t packedBatch,
int opt, int opt,
bool packOutput, bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize){ int tileSize){
...@@ -572,17 +693,24 @@ torch::Tensor transducer_joint_cuda_forward( ...@@ -572,17 +693,24 @@ torch::Tensor transducer_joint_cuda_forward(
const auto maxFLen = f.size(1); const auto maxFLen = f.size(1);
const auto maxGLen = g.size(1); const auto maxGLen = g.size(1);
const auto hiddenSize = f.size(2); const auto hiddenSize = f.size(2);
bool masked = dropout or relu;
int64_t *batchOffsetPtr = nullptr; int64_t *batchOffsetPtr = nullptr;
torch::Tensor sum; torch::Tensor sum, mask;
auto maskOpt = tensorOpt.dtype(torch::kUInt8);
if (!packOutput){ if (!packOutput){
sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt);
batchOffsetPtr = nullptr; batchOffsetPtr = nullptr;
if (masked)
mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt);
} }
else{ else{
sum = torch::empty({packedBatch, hiddenSize}, tensorOpt); sum = torch::empty({packedBatch, hiddenSize}, tensorOpt);
batchOffsetPtr = batchOffset.data_ptr<int64_t>(); batchOffsetPtr = batchOffset.data_ptr<int64_t>();
if (masked)
mask = torch::empty({packedBatch, hiddenSize}, maskOpt);
} }
uint8_t *maskPtr = masked ? mask.data_ptr<uint8_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
...@@ -590,12 +718,13 @@ torch::Tensor transducer_joint_cuda_forward( ...@@ -590,12 +718,13 @@ torch::Tensor transducer_joint_cuda_forward(
// Simple heuristics // Simple heuristics
const int numThread = std::min(128, (static_cast<int>(hiddenSize)+C10_WARP_SIZE-1) const int numThread = std::min(128, (static_cast<int>(hiddenSize)+C10_WARP_SIZE-1)
/ C10_WARP_SIZE * C10_WARP_SIZE); / C10_WARP_SIZE * C10_WARP_SIZE);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
if (opt == 0){ if (opt == 0){
// vanilla kernel // vanilla kernel
const int threads = numThread; const int threads = numThread;
const dim3 blocks(maxGLen, maxFLen, batchSize); const dim3 blocks(maxGLen, maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
transducer_joint_forward<scalar_t, OffsetCalFwd> transducer_joint_forward<scalar_t, OffsetCalFwd>
<<<blocks, threads, 0, stream>>>( <<<blocks, threads, 0, stream>>>(
f.data_ptr<scalar_t>(), f.data_ptr<scalar_t>(),
...@@ -608,54 +737,111 @@ torch::Tensor transducer_joint_cuda_forward( ...@@ -608,54 +737,111 @@ torch::Tensor transducer_joint_cuda_forward(
hiddenSize, hiddenSize,
packOutput, packOutput,
sum.data_ptr<scalar_t>()); sum.data_ptr<scalar_t>());
} }));
if (opt == 1){ }
// tiled version. For simplicity, assume tileF == tileG, even though the kernel can if (opt == 1){
// support more general cases. // tiled version. For simplicity, assume tileF == tileG, even though the kernel can
const int threads = numThread; // support more general cases.
const int hiddenPerBlock = numThread; const int threads = numThread;
const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; const int hiddenPerBlock = numThread;
const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock, const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
(maxFLen+tileSize-1)/tileSize, const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock,
batchSize); (maxFLen+tileSize-1)/tileSize,
batchSize);
TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4,
TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4,
"Expected tileSize to be in [1, 2, 4], but got ", tileSize); "Expected tileSize to be in [1, 2, 4], but got ", tileSize);
switch (tileSize) {
#define LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(tile) case tile:\
transducer_joint_tiled_forward<scalar_t, tile, tile, OffsetCalFwd>\
<<<blocks, threads, 0, stream>>>(\
f.data_ptr<scalar_t>(),\
g.data_ptr<scalar_t>(),\
fLen.data_ptr<int>(),\
gLen.data_ptr<int>(),\
batchOffsetPtr,\
maxFLen,\
maxGLen,\
hiddenSize,\
hiddenPerBlock,\
packOutput,\
sum.data_ptr<scalar_t>());\
break;
LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(1);
LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(2);
LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(4);
}
at::PhiloxCudaState rng_engine_inputs;
if (masked){
// Set up the PRNG state when the operation is masked. For non-masked calls,
// rng_engine_inputs is only used as a placeholder argument,
// so there is no need to initialize it.
c10::optional<at::Generator> gen_;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_,
at::cuda::detail::getDefaultCUDAGenerator());
// counterOffset records how many cuRAND calls each thread makes. For a tiled kernel,
// each thread processes tileF * tileG output elements.
int64_t counterOffset = tileSize * tileSize;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(counterOffset);
}
} }
}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*,
int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float,
at::PhiloxCudaState, scalar_t*, uint8_t*);
if (masked){
switch (tileSize){
case 2:
kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd,
true>;
break;
case 4:
kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd,
true>;
break;
}
}
else{
switch (tileSize){
case 1:
kernel = &transducer_joint_tiled_forward<scalar_t, 1, 1, 4, OffsetCalFwd,
false>;
break;
case 2:
kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd,
false>;
break;
case 4:
kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd,
false>;
break;
}
}
kernel<<<blocks, threads, 0, stream>>>(
f.data_ptr<scalar_t>(),
g.data_ptr<scalar_t>(),
fLen.data_ptr<int>(),
gLen.data_ptr<int>(),
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
hiddenPerBlock,
packOutput,
relu,
dropout,
1.0f - dropoutProb,
rng_engine_inputs,
sum.data_ptr<scalar_t>(),
maskPtr);
}));
}
THCudaCheck(cudaGetLastError()); THCudaCheck(cudaGetLastError());
return sum; if (masked)
return {sum, mask};
else
return {sum};
} }
std::vector<torch::Tensor> transducer_joint_cuda_backward( std::vector<torch::Tensor> transducer_joint_cuda_backward(
torch::Tensor grad, std::vector<torch::Tensor> in,
torch::Tensor fLen, torch::Tensor fLen,
torch::Tensor gLen, torch::Tensor gLen,
torch::Tensor batchOffset, torch::Tensor batchOffset,
int maxFLen, int maxFLen,
int maxGLen, int maxGLen,
bool packOutput){ bool packOutput,
float scale){
auto grad = in[0];
bool masked = (in.size() == 2);
uint8_t *maskPtr = masked ? in[1].data_ptr<uint8_t>() : nullptr;
auto tensorOpt = grad.options(); auto tensorOpt = grad.options();
auto dtype = grad.scalar_type(); auto dtype = grad.scalar_type();
...@@ -709,35 +895,76 @@ std::vector<torch::Tensor> transducer_joint_cuda_backward( ...@@ -709,35 +895,76 @@ std::vector<torch::Tensor> transducer_joint_cuda_backward(
const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor), const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor),
maxFLen+maxGLen, maxFLen+maxGLen,
batchSize); batchSize);
transducer_joint_combined_vec_backward if (masked){
<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd> transducer_joint_combined_vec_backward
<<<blocks, threads, smemSize*sizeof(acc_t)>>>( <scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, true>
gradPtr, <<<blocks, threads, smemSize*sizeof(acc_t)>>>(
fLenPtr, gradPtr,
gLenPtr, maskPtr,
batchOffsetPtr, fLenPtr,
maxFLen, gLenPtr,
maxGLen, batchOffsetPtr,
hiddenSize, maxFLen,
packOutput, maxGLen,
fGradPtr, hiddenSize,
gGradPtr); packOutput,
scale,
fGradPtr,
gGradPtr);
}
else{
transducer_joint_combined_vec_backward
<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, false>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
} }
else{ else{
const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE, const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE,
maxFLen + maxGLen, batchSize); maxFLen + maxGLen, batchSize);
transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd> if (masked){
<<<blocks, threads, smemSize*sizeof(acc_t)>>>( transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, true>
gradPtr, <<<blocks, threads, smemSize*sizeof(acc_t)>>>(
fLenPtr, gradPtr,
gLenPtr, maskPtr,
batchOffsetPtr, fLenPtr,
maxFLen, gLenPtr,
maxGLen, batchOffsetPtr,
hiddenSize, maxFLen,
packOutput, maxGLen,
fGradPtr, hiddenSize,
gGradPtr); packOutput,
scale,
fGradPtr,
gGradPtr);
}
else{
transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, false>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
} }
})); }));
......
...@@ -408,7 +408,7 @@ __global__ void transducer_loss_fused_backward( ...@@ -408,7 +408,7 @@ __global__ void transducer_loss_fused_backward(
: batch * maxFLen * maxGLen; : batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen; const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU; __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
if (t < myFLen and u < myGLen){ if (t < myFLen and u < myGLen){
...@@ -421,6 +421,9 @@ __global__ void transducer_loss_fused_backward( ...@@ -421,6 +421,9 @@ __global__ void transducer_loss_fused_backward(
if (tid == 0){ if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u]; myBetaTU = myBeta[t*maxGLen + u];
myBetaTUp1 = myBeta[t*maxGLen + u + 1];
myBetaTp1U = myBeta[(t+1)*maxGLen + u];
myLabelShared = myLabel[u];
} }
__syncthreads(); __syncthreads();
...@@ -429,14 +432,14 @@ __global__ void transducer_loss_fused_backward( ...@@ -429,14 +432,14 @@ __global__ void transducer_loss_fused_backward(
// Do the update // Do the update
acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x)) acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU); acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabel[u]){ if (u != myGLen - 1 and h == myLabelShared){
myGrad -= std::exp(grad + myBeta[t*maxGLen + u + 1]); myGrad -= std::exp(grad + myBetaTUp1);
} }
else if (h == blankIdx){ else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1) if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad); myGrad -= std::exp(grad);
else if (t != myFLen - 1) else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBeta[(t+1)*maxGLen + u]); myGrad -= std::exp(grad + myBetaTp1U);
} }
myXGrad[h] = myGrad; myXGrad[h] = myGrad;
} }
...@@ -450,6 +453,104 @@ __global__ void transducer_loss_fused_backward( ...@@ -450,6 +453,104 @@ __global__ void transducer_loss_fused_backward(
} }
// Vectorized version of the fused transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused into this kernel.
// Each thread block works on [batch, t, u, :] of the data. Each thread works on a specific h at a time.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t, typename vec_t, int V>
__global__ void transducer_loss_fused_vec_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int u = blockIdx.x;
const int t = blockIdx.y;
const int batch = blockIdx.z;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myLabel = label + batch*(maxGLen-1);
// Variables for vectorization
scalar_t myXBuffer[V], myXGradBuffer[V];
auto myXVec = reinterpret_cast<vec_t const *>(myX);
auto myXGradVec = reinterpret_cast<vec_t*>(myXGrad);
auto myXBufferVec = reinterpret_cast<vec_t*>(myXBuffer);
auto myXGradBufferVec = reinterpret_cast<vec_t*>(myXGradBuffer);
if (t < myFLen and u < myGLen){
// load and store shared variables in SMEM
if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u];
if (t != myFLen - 1)
myBetaTp1U = myBeta[(t+1)*maxGLen + u];
if (u != myGLen - 1){
myBetaTUp1 = myBeta[t*maxGLen + u + 1];
myLabelShared = myLabel[u];
}
}
__syncthreads();
#pragma unroll
for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){
// Load myX in a vector form
*myXBufferVec = myXVec[h0/V];
// Do the update for a vector of input
#pragma unroll
for (int i = 0; i < V; ++i){
auto h = h0 + i;
acc_t grad = commonFactor + myXBuffer[i]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabelShared){
myGrad -= std::exp(grad + myBetaTUp1);
}
else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad);
else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBetaTp1U);
}
myXGradBuffer[i] = myGrad;
}
// Store myXGrad in a vector form
myXGradVec[h0/V] = *myXGradBufferVec;
}
}
else if (!packedInput){
// In non-packed mode, we need to make sure the gradients for the don't-care regions are zero.
for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){
myXGradVec[h0/V] = 0;
}
}
}
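For reference, this is the per-element gradient that both fused-backward kernels (the scalar one earlier in this file and the vectorized one above) write to xGrad, transcribed directly from the update code rather than re-derived. With T = myFLen, U = myGLen, blank symbol ∅ = blankIdx, and c = \ln(\mathrm{lossGrad}_b) + \alpha_{t,u} - \beta_{0,0}:

\frac{\partial L}{\partial x_{t,u,h}}
  = e^{\,c + x_{t,u,h} + \beta_{t,u}}
  - \mathbf{1}[\,u < U-1,\ h = y_u\,]\, e^{\,c + x_{t,u,h} + \beta_{t,u+1}}
  - \mathbf{1}[\,h = \varnothing\,]\bigl(
        \mathbf{1}[\,t = T-1,\ u = U-1\,]\, e^{\,c + x_{t,u,h}}
      + \mathbf{1}[\,t < T-1\,]\, e^{\,c + x_{t,u,h} + \beta_{t+1,u}}\bigr)

Positions with t >= fLen or u >= gLen receive a zero gradient in the non-packed mode.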
std::vector<torch::Tensor> transducer_loss_cuda_forward( std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x, torch::Tensor x,
...@@ -586,23 +687,51 @@ torch::Tensor transducer_loss_cuda_backward( ...@@ -586,23 +687,51 @@ torch::Tensor transducer_loss_cuda_backward(
const dim3 blocks(maxGLen, maxFLen, batchSize); const dim3 blocks(maxGLen, maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] { AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
using vec_t = uint64_t;
using acc_t = at::acc_type<scalar_t, true>; using acc_t = at::acc_type<scalar_t, true>;
transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>( constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
x.data_ptr<scalar_t>(), constexpr int vecAlignment = std::alignment_of<vec_t>::value;
lossGrad.data_ptr<scalar_t>(), // if all input and output tensors meet the alignment requirement
audLen.data_ptr<int>(), bool memAlign = reinterpret_cast<uint64_t>(x.data_ptr<scalar_t>()) % vecAlignment == 0
txtLen.data_ptr<int>(), and reinterpret_cast<uint64_t>(xGrad.data_ptr<scalar_t>())
label.data_ptr<int>(), % vecAlignment == 0;
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(), if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){
batchOffsetPtr, transducer_loss_fused_vec_backward<scalar_t, acc_t, vec_t, vectFactor>
dictSize, <<<blocks, threads, 0, stream>>>(
blankIdx, x.data_ptr<scalar_t>(),
maxFLen, lossGrad.data_ptr<scalar_t>(),
maxGLen, audLen.data_ptr<int>(),
packedInput, txtLen.data_ptr<int>(),
xGrad.data_ptr<scalar_t>()); label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}
else{
transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}
})); }));
} }
else{ else{
......
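The launcher above only takes the vectorized loss-backward path when the dictionary dimension is divisible by the vector width and the tensors are 8-byte aligned. A sketch of that check in Python terms, with illustrative names:

import torch

def use_vectorized_path(x: torch.Tensor, x_grad: torch.Tensor, dict_size: int) -> bool:
    # vec_t is uint64_t in the kernel dispatch, i.e. 8-byte vector loads/stores
    vect_factor = 8 // x.element_size()          # 4 for fp16, 2 for fp32
    mem_align = x.data_ptr() % 8 == 0 and x_grad.data_ptr() % 8 == 0
    return vect_factor > 1 and dict_size % vect_factor == 0 and mem_align

x = torch.randn(5, 51, 26, 16, dtype=torch.float16)
print(use_vectorized_path(x, torch.empty_like(x), dict_size=x.size(-1)))

With fp16 inputs the vector width is 4, which is why the new loss tests below use a dictionary size of 16 when for_vector_kernel=True and 14 otherwise.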
...@@ -28,6 +28,7 @@ class TransducerJointTest(unittest.TestCase): ...@@ -28,6 +28,7 @@ class TransducerJointTest(unittest.TestCase):
self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device) self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
self.dropout_prob = 0.5
# Make sure gradients from out-of-bound locations are zero. This should be guaranteed by # Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
# the loss function # the loss function
...@@ -49,30 +50,38 @@ class TransducerJointTest(unittest.TestCase): ...@@ -49,30 +50,38 @@ class TransducerJointTest(unittest.TestCase):
batch_offset = torch.cumsum(f_len * g_len, dim=0) batch_offset = torch.cumsum(f_len * g_len, dim=0)
return x_packed return x_packed
def _unpack(self, x, f_len, g_len):
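# Unpack a packed [B_packed, H] tensor (e.g. the uint8 ReLU/dropout mask returned in packed
# mode) back into the padded [B, T, U, H] layout using the per-batch lengths.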
batch_offset = torch.cumsum(f_len * g_len, dim=0)
x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)
B = self.h_grad.size(0)
H = self.h_grad.size(-1)
for b in range(B):
my_batch_offset = 0 if b == 0 else batch_offset[b-1]
my_f_len = f_len[b]
my_g_len = g_len[b]
for t in range(my_f_len):
x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len :
my_batch_offset + t*my_g_len + my_g_len]
return x_unpacked
def run_transducer_joint(self, for_vector_kernel, pack_output): def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):
self.gen_input(for_vector_kernel=for_vector_kernel) self.gen_input(for_vector_kernel=for_vector_kernel)
# Generate reference # Generate reference
f_ref = self.f_tst.data.clone() f_ref = self.f_tst.data.clone()
g_ref = self.g_tst.data.clone() g_ref = self.g_tst.data.clone()
f_ref.requires_grad = True f_ref.requires_grad = True
g_ref.requires_grad = True g_ref.requires_grad = True
h_ref, f_grad_ref, g_grad_ref \
= transducer_ref.transducer_joint_reference(f=f_ref,
g=g_ref,
h_grad=self.h_grad,
f_len=self.f_len,
g_len=self.g_len,
pack_output=pack_output)
my_joint= TransducerJoint(pack_output=pack_output) my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout,
dropout_prob=self.dropout_prob, probe_mask=True)
if not pack_output: if not pack_output:
h_tst = my_joint( f=self.f_tst, h_tst = my_joint( f=self.f_tst,
g=self.g_tst, g=self.g_tst,
f_len=self.f_len, f_len=self.f_len,
g_len=self.g_len) g_len=self.g_len)
h_tst.backward(self.h_grad) h_tst.backward(self.h_grad)
if dropout:
mask = my_joint.mask_probe[0]
else: else:
batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0) batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
h_tst = my_joint( f=self.f_tst, h_tst = my_joint( f=self.f_tst,
...@@ -82,6 +91,22 @@ class TransducerJointTest(unittest.TestCase): ...@@ -82,6 +91,22 @@ class TransducerJointTest(unittest.TestCase):
batch_offset=batch_offset, batch_offset=batch_offset,
packed_batch=batch_offset[-1]) packed_batch=batch_offset[-1])
h_tst.backward(self.h_grad_packed) h_tst.backward(self.h_grad_packed)
if dropout:
mask_packed = my_joint.mask_probe[0]
mask = self._unpack(mask_packed, self.f_len, self.g_len)
# reference
h_ref, f_grad_ref, g_grad_ref \
= transducer_ref.transducer_joint_reference(f=f_ref,
g=g_ref,
h_grad=self.h_grad,
f_len=self.f_len,
g_len=self.g_len,
pack_output=pack_output,
relu=relu,
dropout=dropout,
dropout_prob=self.dropout_prob,
mask=mask if dropout else None)
f_grad_tst = self.f_tst.grad f_grad_tst = self.f_tst.grad
g_grad_tst = self.g_tst.grad g_grad_tst = self.g_tst.grad
...@@ -91,16 +116,41 @@ class TransducerJointTest(unittest.TestCase): ...@@ -91,16 +116,41 @@ class TransducerJointTest(unittest.TestCase):
self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4)) self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4))
def test_transducer_joint(self): def test_transducer_joint(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=False) self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_vec(self): def test_transducer_joint_vec(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False) self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)
def test_transducer_joint_pack(self): def test_transducer_joint_pack(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True) self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_vec_pack(self): def test_transducer_joint_vec_pack(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True) self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_vec_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)
def test_transducer_joint_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_vec_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
def test_transducer_joint_vec_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)
def test_transducer_joint_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)
def test_transducer_joint_vec_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -8,13 +8,13 @@ class TransducerLossTest(unittest.TestCase): ...@@ -8,13 +8,13 @@ class TransducerLossTest(unittest.TestCase):
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
def gen_input(self, scalar_t): def gen_input(self, scalar_t, for_vector_kernel):
self.B = 5 self.B = 5
T_min = 23 T_min = 23
T_max = 51 T_max = 51
U_min = 12 U_min = 12
U_max = 25 U_max = 25
V = 16 V = 16 if for_vector_kernel else 14
self.blank_idx = V - 1 self.blank_idx = V - 1
device = "cuda" device = "cuda"
...@@ -61,8 +61,8 @@ class TransducerLossTest(unittest.TestCase): ...@@ -61,8 +61,8 @@ class TransducerLossTest(unittest.TestCase):
x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u] x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u]
return x_unpacked return x_unpacked
def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input): def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel):
self.gen_input(scalar_t) self.gen_input(scalar_t, for_vector_kernel)
my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward, my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward,
packed_input=packed_input) packed_input=packed_input)
if not packed_input: if not packed_input:
...@@ -90,28 +90,40 @@ class TransducerLossTest(unittest.TestCase): ...@@ -90,28 +90,40 @@ class TransducerLossTest(unittest.TestCase):
def test_transducer_loss_fp32(self): def test_transducer_loss_fp32(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32, loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32,
fuse_softmax_backward=False, fuse_softmax_backward=False,
packed_input=False) packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5))
def test_transducer_loss_fp16(self): def test_transducer_loss_fp16(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=False, fuse_softmax_backward=False,
packed_input=False) packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion(self): def test_transducer_loss_fp16_backward_fusion(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True, fuse_softmax_backward=True,
packed_input=False) packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed(self): def test_transducer_loss_fp16_backward_fusion_packed(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True, fuse_softmax_backward=True,
packed_input=True) packed_input=True,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed_vec(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True,
for_vector_kernel=True)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
......
...@@ -76,12 +76,21 @@ def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad): ...@@ -76,12 +76,21 @@ def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):
return alpha, beta, x.grad, loss return alpha, beta, x.grad, loss
def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output): def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout,
dropout_prob=0, mask=None):
if dropout and mask is None:
raise NotImplementedError("mask needs to be supplied to test dropout.")
B, T, H = f.size() B, T, H = f.size()
U = g.size(1) U = g.size(1)
f_expand = f.unsqueeze(dim=2) f_expand = f.unsqueeze(dim=2)
g_expand = g.unsqueeze(dim=1) g_expand = g.unsqueeze(dim=1)
h = f_expand + g_expand h = f_expand + g_expand
if relu:
h = torch.nn.functional.relu(h)
if dropout:
h *= mask
scale = 1/(1-dropout_prob)
h *= scale
h.backward(h_grad) h.backward(h_grad)
if pack_output == False: if pack_output == False:
...@@ -90,6 +99,7 @@ def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output): ...@@ -90,6 +99,7 @@ def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output):
for b in range(B): for b in range(B):
h[b, f_len[b]:] = -1 h[b, f_len[b]:] = -1
h[b, :, g_len[b]:] = -1 h[b, :, g_len[b]:] = -1
return h, f.grad, g.grad return h, f.grad, g.grad
# packing # packing
......
...@@ -10,18 +10,34 @@ class TransducerJoint(torch.nn.Module): ...@@ -10,18 +10,34 @@ class TransducerJoint(torch.nn.Module):
Arguments: Arguments:
pack_output (bool, optional): whether to pack the output in a compact form with don't-care pack_output (bool, optional): whether to pack the output in a compact form with don't-care
data being removed. (default: False) data being removed. (default: False)
relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1
(default: False)
dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1
(default: False)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm. opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm.
(default: 1) (default: 1)
fwd_tile_size (int, optional): tile size used in forward operation. This argument will be fwd_tile_size (int, optional): tile size used in forward operation. This argument will be
ignored if opt != 1. (default: 4) ignored if opt != 1. (default: 4)
dropout_prob (float, optional): dropout probability. (default: 0.0)
probe_mask (bool, optional): a flag used to probe the mask generated by the ReLU and/or dropout
operation. When this argument is set to True, the mask can be accessed through
self.mask_probe. (default: False)
""" """
def __init__(self, pack_output=False, opt=1, fwd_tile_size=4): def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4,
dropout_prob=0, probe_mask=False):
super(TransducerJoint, self).__init__() super(TransducerJoint, self).__init__()
self.pack_output = pack_output self.pack_output = pack_output
self.relu = relu
self.dropout = dropout
self.dropout_prob = dropout_prob
self.opt = opt self.opt = opt
self.fwd_tile_size = fwd_tile_size self.fwd_tile_size = fwd_tile_size
self.dummy_batch_offset = torch.empty(0) self.dummy_batch_offset = torch.empty(0)
masked = self.relu or self.dropout
self.mask_probe = [] if masked and probe_mask else None
if masked and opt != 1:
raise NotImplementedError("ReLU and dropout fusion is only supported with opt=1")
def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0): def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):
...@@ -43,8 +59,10 @@ class TransducerJoint(torch.nn.Module): ...@@ -43,8 +59,10 @@ class TransducerJoint(torch.nn.Module):
my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset
if self.pack_output and (batch_offset is None or packed_batch == 0): if self.pack_output and (batch_offset is None or packed_batch == 0):
raise Exception("Please specify batch_offset and packed_batch when packing is enabled") raise Exception("Please specify batch_offset and packed_batch when packing is enabled")
return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, my_batch_offset, dropout = self.dropout and self.training # only dropout for training
packed_batch, self.opt, self.fwd_tile_size) return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout,
my_batch_offset, packed_batch, self.opt,
self.fwd_tile_size, self.dropout_prob, self.mask_probe)
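Putting the new options together, typical usage mirrors the unit tests above (assuming the usual apex.contrib.transducer import path; shapes are illustrative):

import torch
from apex.contrib.transducer import TransducerJoint

B, T, U, H = 2, 4, 3, 64
f = torch.randn(B, T, H, device="cuda", requires_grad=True)
g = torch.randn(B, U, H, device="cuda", requires_grad=True)
f_len = torch.full((B,), T, dtype=torch.int, device="cuda")
g_len = torch.full((B,), U, dtype=torch.int, device="cuda")

joint = TransducerJoint(pack_output=False, relu=True, dropout=True,
                        dropout_prob=0.5, probe_mask=True)
h = joint(f=f, g=g, f_len=f_len, g_len=g_len)   # [B, T, U, H]; dropout is active in training mode
h.backward(torch.randn_like(h))
mask = joint.mask_probe[0]                      # uint8 ReLU/dropout mask, available since probe_mask=True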
class TransducerLoss(torch.nn.Module): class TransducerLoss(torch.nn.Module):
...@@ -139,23 +157,39 @@ class TransducerLossFunc(torch.autograd.Function): ...@@ -139,23 +157,39 @@ class TransducerLossFunc(torch.autograd.Function):
class TransducerJointFunc(torch.autograd.Function): class TransducerJointFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, f, g, f_len, g_len, pack_output, batch_offset, packed_batch, opt, def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch,
fwd_tile_size): opt, fwd_tile_size, dropout_prob, mask_probe):
h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt, h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt,
pack_output, fwd_tile_size) pack_output, relu, dropout, dropout_prob, fwd_tile_size)
ctx.save_for_backward(f_len, g_len, batch_offset) masked = relu or dropout
if masked:
ctx.save_for_backward(h[1], f_len, g_len, batch_offset)
if mask_probe is not None:
mask_probe.append(h[1])
else:
ctx.save_for_backward(f_len, g_len, batch_offset)
ctx.pack_output = pack_output ctx.pack_output = pack_output
ctx.masked = relu or dropout
ctx.max_f_len = f.size(1) ctx.max_f_len = f.size(1)
ctx.max_g_len = g.size(1) ctx.max_g_len = g.size(1)
return h ctx.scale = 1 / (1-dropout_prob) if dropout and dropout_prob != 1 else 1
return h[0]
@staticmethod @staticmethod
def backward(ctx, loss_grad): def backward(ctx, loss_grad):
f_len, g_len, batch_offset = ctx.saved_tensors if ctx.masked:
f_grad, g_grad = transducer_joint_cuda.backward(loss_grad, f_len, g_len, batch_offset, mask, f_len, g_len, batch_offset = ctx.saved_tensors
ctx.max_f_len, ctx.max_g_len, inp = [loss_grad, mask]
ctx.pack_output) else:
f_len, g_len, batch_offset = ctx.saved_tensors
inp = [loss_grad]
f_grad, g_grad = transducer_joint_cuda.backward( inp, f_len, g_len, batch_offset,
ctx.max_f_len, ctx.max_g_len,
ctx.pack_output, ctx.scale)
return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None, None, \
None, None, None
...@@ -512,7 +512,8 @@ if "--transducer" in sys.argv: ...@@ -512,7 +512,8 @@ if "--transducer" in sys.argv:
'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'], 'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3',
'-I./apex/contrib/csrc/multihead_attn/'] + version_dependent_macros}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='transducer_loss_cuda', CUDAExtension(name='transducer_loss_cuda',
sources=['apex/contrib/csrc/transducer/transducer_loss.cpp', sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
......