Commit f79993d9 authored by hubertlu-tw

Merge remote-tracking branch 'upstream/master' into IFU-master-2021-10-15

parents 297ab210 1d5f7e55
#include "ATen/ATen.h"
#include <THC/THCDeviceUtils.cuh>
#include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
@@ -8,11 +8,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_beta3,
at::Tensor per_tensor_bias_correction,
const int step,
at::Tensor step,
at::Tensor per_tensor_epsilon,
const int mode,
at::Tensor per_tensor_decay,
const float grad_scale);
at::Tensor global_scale,
at::Tensor global_grad_norm,
const float max_grad_norm);
void multi_tensor_lamb_update_weights_cuda(
int chunk_size,
@@ -20,8 +22,10 @@ void multi_tensor_lamb_update_weights_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate,
at::Tensor update_norm_offset,
at::Tensor learning_rate,
at::Tensor per_tensor_decay,
at::Tensor global_grad_norm,
bool use_nvlamb);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -116,28 +116,36 @@ struct DistOptLAMBStage1Functor
const MATH_T* per_tensor_beta2,
const MATH_T* per_tensor_beta3,
const int* per_tensor_bias_correction,
const int step,
const int* step,
const MATH_T* per_tensor_epsilon,
adamMode_t mode,
const MATH_T* per_tensor_decay,
const float grad_scale)
const MATH_T* global_scale,
const MATH_T* global_grad_norm,
const float max_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
if (*noop_gmem == 1)
return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float combined_scale = *global_scale;
if (max_grad_norm > 0) {
combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);
combined_scale = *global_scale / std::min((float) 1.0, combined_scale);
}
MATH_T beta1 = per_tensor_beta1[tensor_num];
MATH_T beta2 = per_tensor_beta2[tensor_num];
MATH_T beta3 = 1 - beta1;
MATH_T beta1_correction, beta2_correction;
if (per_tensor_bias_correction[tensor_num] == 1) {
beta1_correction = 1 - pow(beta1, step);
beta2_correction = 1 - pow(beta2, step);
beta1_correction = 1 - pow(beta1, *step);
beta2_correction = 1 - pow(beta2, *step);
} else {
beta1_correction = (MATH_T) 1.0;
beta2_correction = (MATH_T) 1.0;
@@ -204,7 +212,7 @@ struct DistOptLAMBStage1Functor
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / grad_scale;
MATH_T scaled_grad = r_g[ii] / combined_scale;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -215,7 +223,7 @@ struct DistOptLAMBStage1Functor
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / grad_scale;
MATH_T scaled_grad = r_g[ii] / combined_scale;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -274,7 +282,7 @@ struct DistOptLAMBStage1Functor
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / grad_scale;
MATH_T scaled_grad = r_g[ii] / combined_scale;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -285,7 +293,7 @@ struct DistOptLAMBStage1Functor
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / grad_scale;
MATH_T scaled_grad = r_g[ii] / combined_scale;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -321,13 +329,15 @@ struct DistOptLAMBStage2Functor
TensorListMetadata<3>& tl,
const MATH_T* per_tensor_param_norm,
const MATH_T* per_tensor_update_norm,
const MATH_T learning_rate,
const long* update_norm_offset,
const MATH_T* learning_rate,
const MATH_T* per_tensor_decay,
const MATH_T* global_grad_norm,
bool use_nvlamb)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
if (*noop_gmem == 1)
return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
@@ -336,14 +346,14 @@ struct DistOptLAMBStage2Functor
MATH_T decay = per_tensor_decay[tensor_num];
MATH_T ratio = learning_rate;
MATH_T ratio = *learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != (MATH_T) 0.0))
{
MATH_T param_norm = per_tensor_param_norm[tensor_num];
MATH_T update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];
ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate);
}
MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
@@ -374,7 +384,7 @@ struct DistOptLAMBStage2Functor
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
convert(r_p[ii], r_p_copy[ii]);
}
load_store(p, r_p, i_start, 0);
@@ -427,11 +437,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_beta3,
at::Tensor per_tensor_bias_correction,
const int step,
at::Tensor step,
at::Tensor per_tensor_epsilon,
const int mode,
at::Tensor per_tensor_decay,
const float grad_scale)
at::Tensor global_scale,
at::Tensor global_grad_norm,
const float max_grad_norm)
{
using namespace at;
@@ -448,11 +460,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
per_tensor_beta2.DATA_PTR<scalar_t_2>(),
per_tensor_beta3.DATA_PTR<scalar_t_2>(),
per_tensor_bias_correction.DATA_PTR<int>(),
step,
step.DATA_PTR<int>(),
per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
(adamMode_t) mode,
per_tensor_decay.DATA_PTR<scalar_t_2>(),
grad_scale); )))
global_scale.DATA_PTR<scalar_t_2>(),
global_grad_norm.DATA_PTR<scalar_t_2>(),
max_grad_norm); )))
AT_CUDA_CHECK(cudaGetLastError());
}
@@ -463,8 +477,10 @@ void multi_tensor_lamb_update_weights_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate,
at::Tensor update_norm_offset,
at::Tensor learning_rate,
at::Tensor per_tensor_decay,
at::Tensor global_grad_norm,
bool use_nvlamb)
{
using namespace at;
@@ -480,8 +496,10 @@ void multi_tensor_lamb_update_weights_cuda(
DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
per_tensor_param_norm.DATA_PTR<scalar_t_2>(),
per_tensor_update_norm.DATA_PTR<scalar_t_2>(),
(scalar_t_2) learning_rate,
update_norm_offset.DATA_PTR<long>(),
learning_rate.DATA_PTR<scalar_t_2>(),
per_tensor_decay.DATA_PTR<scalar_t_2>(),
global_grad_norm.DATA_PTR<scalar_t_2>(),
use_nvlamb); )))
AT_CUDA_CHECK(cudaGetLastError());
#include <torch/extension.h>
#include <ATen/Functions.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize);
std::vector<torch::Tensor> transducer_joint_cuda_backward(
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput,
float scale);
std::vector<torch::Tensor> transducer_joint_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize) {
CHECK_INPUT(f);
CHECK_INPUT(g);
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_forward(
f,
g,
fLen,
gLen,
batchOffset,
packedBatch,
opt,
packOutput,
relu,
dropout,
dropoutProb,
tileSize);
}
std::vector<torch::Tensor> transducer_joint_backward(
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput,
float scale) {
for (auto t : in){
CHECK_INPUT(t);
}
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_backward(
in,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
packOutput,
scale);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)");
m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)");
}
\ No newline at end of file
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/macros/Macros.h>
#include <THC/THC.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h>
#include "philox.h"
// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
// width should be a power of 2 and no larger than warpSize.
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){
for (unsigned offset = width/2; offset > 0; offset /= 2){
x += __shfl_down_sync(0xffffffff, x, offset, width);
}
return x;
}
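// Illustrative usage sketch (an assumption added for clarity, not part of the extension):
// launched with a single warp and assuming C10_WARP_SIZE == 32, one warp reduces
// 32 / 8 = 4 independent groups of 8 lanes each; lanes 0, 8, 16 and 24 end up holding
// the sums of their respective groups.
__global__ void warpReduceExampleKernel(const float *in, float *out){
    float v = in[threadIdx.x];         // one element per lane
    v = warpReduce(v, /*width=*/8);    // width = 8 -> N = warpSize / width = 4 groups
    if (threadIdx.x % 8 == 0)
        out[threadIdx.x / 8] = v;      // first lane of each group writes its sum
}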
inline int largestPowerOfTwo(int x){
int y = 1;
while (y <= x)
y <<= 1;
return y >> 1;
}
/*
Figure out vectorization type for masks.
Similar to how PyTorch figures out acc_t here:
aten/src/ATen/AccumulateType.h
*/
template <int V>
struct MaskVecType { };
template <> struct MaskVecType<1> { using type = uint8_t; };
template <> struct MaskVecType<2> { using type = uint16_t; };
template <> struct MaskVecType<4> { using type = uint32_t; };
template<int V>
using mvec_type = typename MaskVecType<V>::type;
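// Minimal sanity-check sketch of the mapping above (illustrative only): a vector of V
// uint8_t masks is loaded/stored as a single V-byte word.
static_assert(sizeof(mvec_type<2>) == 2, "two uint8_t masks should move as one 16-bit word");
static_assert(sizeof(mvec_type<4>) == 4, "four uint8_t masks should move as one 32-bit word");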
// Helper class to calculate pointer offset that can be shared by different flavors of kernels.
// For fwd, batch offset and stride are different for packing and non-packing mode.
struct OffsetCalFwd{
__device__ __forceinline__ OffsetCalFwd(
int64_t batch,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t gLen,
int64_t hiddenSize,
bool packOutput) :
batch(batch),
batchOffset(batchOffset),
maxFLen(maxFLen),
maxGLen(maxGLen),
gLen(gLen),
hiddenSize(hiddenSize),
packOutput(packOutput)
{}
int64_t batch;
const int64_t *batchOffset;
int64_t maxFLen;
int64_t maxGLen;
int64_t gLen;
int64_t hiddenSize;
bool packOutput;
__device__ __forceinline__ int64_t getBatchOffset(){
return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize
: batch*maxFLen*maxGLen*hiddenSize;
}
__device__ __forceinline__ int64_t getStrideF(){
return packOutput ? gLen*hiddenSize : maxGLen*hiddenSize;
}
};
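// Worked example for the accessors above (illustrative numbers): with hiddenSize = 4,
// maxFLen = 3, maxGLen = 5, fLen = {3, 2, 3}, gLen = {5, 2, 4} and the cumulative
// batchOffset = {15, 19, 31}, batch = 2 gives
//   packOutput == false: getBatchOffset() = 2*3*5*4 = 120, getStrideF() = 5*4 = 20
//   packOutput == true : getBatchOffset() = batchOffset[1]*4 = 76, getStrideF() = gLen*4 = 16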
// Helper class to calculate pointer offset that can be shared by different flavors of kernels
// For bwd, batch offset and stride are different for packing and non-packing mode.
// The reduction is done for two input tensors. Therefore, generating two sets of offsets
// according to bwdFasterDim can lead to a unified implementation in the actual kernel.
struct OffsetCalBwd{
__device__ __forceinline__ OffsetCalBwd(
int64_t batch,
const int64_t *batchOffset,
const int *fLen,
const int *gLen,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim) :
batch(batch),
batchOffset(batchOffset),
maxFLen(maxFLen),
maxGLen(maxGLen),
fLen(fLen),
gLen(gLen),
hiddenSize(hiddenSize),
packOutput(packOutput),
bwdFasterDim(bwdFasterDim)
{}
int64_t batch;
const int64_t *batchOffset;
const int *fLen;
const int *gLen;
int64_t maxFLen;
int64_t maxGLen;
int64_t hiddenSize;
bool packOutput;
bool bwdFasterDim; // whether doing bwd on the faster moving dimension
__device__ __forceinline__ int64_t getBatchOffset(){
return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize
: batch*maxFLen*maxGLen*hiddenSize;
}
__device__ __forceinline__ int64_t getMaxXLen(){
return bwdFasterDim ? maxGLen : maxFLen;
}
__device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]){
return bwdFasterDim ? gLen[batch] : fLen[batch];
}
__device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]){
return bwdFasterDim ? fLen[batch] : gLen[batch];
}
__device__ __forceinline__ int64_t getStrideX(){
return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize);
}
__device__ __forceinline__ int64_t getStrideY(){
return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize;
}
};
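// Illustrative stride choice (assuming packOutput == true, hiddenSize = 4, gLen[batch] = 4):
//   bwdFasterDim == false (reduce grads for f): strideX = gLen[batch]*4 = 16, strideY = 4
//   bwdFasterDim == true  (reduce grads for g): strideX = 4, strideY = gLen[batch]*4 = 16
// i.e. the same kernel walks either dimension of the [T, U] plane simply by swapping strides.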
// Vanilla transducer joint forward kernel
// Detail of this joint function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// f is a tensor of shape [batch, T, H]
// g is a tensor of shape [batch, U, H]
// the transducer joint does
// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
// The resultant tensor is of shape [batch, T, U, H]
// Each thread block is working on one "batch" of data in the output tensor, [batch, t, u, :]
// This joint function can optionally pack the output where the output tensor with a shape of
// [B, T, U, H] is packed into [B_packed, H].
// Don't-care region (t > fLen) or (u > gLen) is removed.
// To enable packing, the starting offset for each batch needs to be specified with batchOffset.
template <typename scalar_t, class OffsetCal>
__global__ void transducer_joint_forward(
const scalar_t *f,
const scalar_t *g,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
scalar_t *sum) {
const int batch = blockIdx.z;
const int t = blockIdx.y;
const int u = blockIdx.x;
const auto myFLen = fLen[batch];
const auto myGLen = gLen[batch];
OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideF = offsetCal.getStrideF();
scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize;
scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize;
scalar_t *mySum = sum + myBatchOffset + t*strideF + u * hiddenSize;
if (t < myFLen and u < myGLen){
#pragma unroll
for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){
if (h < hiddenSize){
mySum[h] = myF[h] + myG[h];
}
}
}
else if (packOutput == false and t < maxFLen and u < maxGLen){
// Need to write finite data to don't-care region because we instantiate the result tensor
// with torch::empty for performance reasons. Even though it is don't-care region, the
// contents need to be finite, otherwise they could lead to NaN in WGRAD.
// In packing mode, this write is no longer necessary as we remove the don't-care region
// from the output.
// Picking -1 (over 0) here for ease of testing.
#pragma unroll
for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){
if (h < hiddenSize){
mySum[h] = -1;
}
}
}
}
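// Host-side reference of the unpacked joint above, kept as an illustrative sketch only
// (this helper is an assumption and is not used by the extension): it reproduces
// sum[b][t][u][h] = f[b][t][h] + g[b][u][h] and the -1 fill of the don't-care region.
inline void transducer_joint_reference(
        const float *f, const float *g, const int *fLen, const int *gLen,
        int64_t batchSize, int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize,
        float *sum){
    for (int64_t b = 0; b < batchSize; ++b)
        for (int64_t t = 0; t < maxFLen; ++t)
            for (int64_t u = 0; u < maxGLen; ++u)
                for (int64_t h = 0; h < hiddenSize; ++h){
                    int64_t idx = ((b*maxFLen + t)*maxGLen + u)*hiddenSize + h;
                    sum[idx] = (t < fLen[b] and u < gLen[b])
                               ? f[(b*maxFLen + t)*hiddenSize + h]
                                 + g[(b*maxGLen + u)*hiddenSize + h]
                               : static_cast<float>(-1);
                }
}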
/*
Tiled version of the joint forward kernel
Detail of this joint function can be found in:
[1] Sequence Transduction with Recurrent Neural Networks.
f is a tensor of shape [batch, T, H]
g is a tensor of shape [batch, U, H]
the transducer joint does
sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
The resultant tensor is of shape [batch, T, U, H]
Each thread works on a tile of shape tileF x tileG in the result tensor.
The f and g inputs for the tile are first loaded into registers and are reused tileG and tileF
times, respectively.
This joint function can optionally pack the output where the output tensor with a shape of
[B, T, U, H] is packed into [B_packed, H].
Don't-care region (t > fLen) or (u > gLen) is removed.
To enable packing, the starting offset for each batch needs to be specified with batchOffset.
Optionally this joint function performs ReLU and/or dropout on the joint output, which is
controlled by the arguments relu and dropout, respectively. philoxArgs is the argument used for
generating pseudorandom numbers. When at least one of ReLU and dropout is enabled, the joint
function becomes a masked operation, which is controlled by the template argument masked. In this
case, masks are saved for the backward pass.
*/
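// Register reuse sketch for the tiling above (illustrative numbers): with tileF = tileG = 4,
// each thread loads 4 elements of f and 4 elements of g into registers and produces the
// 4 x 4 = 16 outputs of its tile, so every loaded element is reused 4 times relative to the
// vanilla kernel.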
template <typename scalar_t, int tileF, int tileG, int U, class OffsetCal, bool masked>
__global__ void transducer_joint_tiled_forward(
const scalar_t *f,
const scalar_t *g,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
int64_t hiddenPerBlock,
bool packOutput,
bool relu,
bool dropout,
float p,
at::PhiloxCudaState philoxArgs,
scalar_t *sum,
uint8_t *mask) {
static_assert(U == 4, "U has to be 4, as random numbers are generated in batch of 4");
const int batch = blockIdx.z;
const int t = blockIdx.y * tileF;
const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
const int u = blockIdx.x / hiddenBlock * tileG;
const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock;
const int h = threadIdx.x;
const auto myFLen = fLen[batch];
const auto myGLen = gLen[batch];
OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideF = offsetCal.getStrideF();
scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset;
scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset;
scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset;
uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset;
// The following code is only needed for dropout. We try to bypass it as much as possible.
auto seeds = masked ? at::cuda::philox::unpack(philoxArgs)
: std::make_tuple(static_cast<uint64_t>(0), static_cast<uint64_t>(0));
uint64_t tid = masked ? (static_cast<uint64_t>(blockIdx.z)*gridDim.y*gridDim.x +
blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x
: 0;
Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0;
bool dropoutMask[U];
if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){
// register buffers for tiled input reuse
scalar_t fBuffer[tileF], gBuffer[tileG];
for (int i = 0; i < tileF; ++i){
if (t + i < myFLen)
fBuffer[i] = myF[i*hiddenSize + h];
}
for (int j = 0; j < tileG; ++j){
if (u + j < myGLen)
gBuffer[j] = myG[j*hiddenSize + h];
}
#pragma unroll
for (int i = 0; i < tileF; ++i){
if (t + i < myFLen){
#pragma unroll
for (int j = 0; j < tileG; ++j){
int idx = i*tileG + j;
if (masked and dropout and idx % U == 0){
// For performance, generate 4 random numbers in one shot
// auto rand4 = curand_uniform4(&state);
auto rand4 = uniform4(ph());
dropoutMask[0] = rand4.x < p;
dropoutMask[1] = rand4.y < p;
dropoutMask[2] = rand4.z < p;
dropoutMask[3] = rand4.w < p;
}
if (u + j < myGLen){
scalar_t out = fBuffer[i] + gBuffer[j];
if (masked){
// Apply ReLU here when relu is True
bool localMask = relu ? (out>0) : 1;
localMask = dropout ? localMask & dropoutMask[idx%U] : localMask;
out = dropout ? out*localMask*scale : out*localMask;
myMask[i*strideF + j*hiddenSize + h] = static_cast<uint8_t>(localMask);
}
mySum[i*strideF + j*hiddenSize + h] = out;
}
else if (packOutput == false and u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
else if (packOutput == false and t + i < maxFLen){
// Again need to write finite data to don't-care region
#pragma unroll
for (int j = 0; j < tileG; ++j){
if (u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
}
}
else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset+h < hiddenSize){
// Only need to ensure finiteness in non-packing mode
#pragma unroll
for (int i = 0; i < tileF; ++i){
if (t + i < maxFLen){
#pragma unroll
for (int j = 0; j < tileG; ++j){
if (u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
}
}
}
/*
Bwd operation (reduction) on one input tensor. Since the operations performed for the two input
tensors are exactly the same, only one kernel is needed, and the different indexing offsets
and strides are handled by OffsetCalBwd.
When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, class OffsetCal, bool masked>
__device__ void transducer_joint_single_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim, // whether bwd on the faster moving dimension (u)
float scale,
scalar_t *inGrad,
int yBlockOffset=0) {
const int batch = blockIdx.z;
// For the second input tensor, this offset needs to be subtracted because the first yBlockOffset
// sets of thread blocks are for the first input tensor.
const int x = blockIdx.y-yBlockOffset;
const int hOffset = blockIdx.x*C10_WARP_SIZE;
const int wid = threadIdx.y;
const int lid = threadIdx.x;
const int numWarp = blockDim.y;
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput,
bwdFasterDim);
const auto maxXLen = offsetCal.getMaxXLen();
const auto myXLen = offsetCal.getMyXLen();
const auto myYLen = offsetCal.getMyYLen();
scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;
if (x < myXLen){
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideX = offsetCal.getStrideX();
const auto strideY = offsetCal.getStrideY();
const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr;
// Each warp reduces numYPerWarp "y" first
acc_t warpSum = 0;
auto numYPerWarp = (myYLen+numWarp-1)/numWarp;
#pragma unroll
for (int warpY = 0; warpY < numYPerWarp; ++warpY){
auto y = wid*numYPerWarp + warpY;
if (y < myYLen and (hOffset+lid) < hiddenSize)
if (masked)
warpSum += static_cast<acc_t>(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale;
else
warpSum += myGrad[y*strideY + lid];
}
// transpose partial sum in SMEM and reduce further using warpReduce
smem[lid*numWarp + wid] = warpSum;
__syncthreads();
auto sum = smem[wid*C10_WARP_SIZE + lid];
sum = warpReduce(sum, numWarp);
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// example of 4 warps (a, b, c, d) with 8 threads per warp
// Each warp needs 8 / 4 = 2 threads to write the results.
if (hOffset+wid*C10_WARP_SIZE/numWarp+lid/numWarp < hiddenSize){
if (lid % numWarp == 0){
myInGrad[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = sum;
}
}
}
else if (wid == 0 and hOffset + lid < hiddenSize){
// Need to ensure the grad is zero for don't care region
myInGrad[lid] = 0;
}
}
/*
The actual bwd (reduction) kernel that gets launched.
Call transducer_joint_single_backward twice on two input tensors.
The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, class OffsetCal, bool masked>
__global__ void transducer_joint_combined_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
float scale,
scalar_t *fGrad,
scalar_t *gGrad) {
if (blockIdx.y < maxFLen){
transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
false,
scale,
fGrad);
}
else{
transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
true,
scale,
gGrad,
maxFLen);
}
}
/*
Vectorized version of transducer_joint_single_backward
Doing exact same operation as transducer_joint_single_backward except the load and store are
vectorized.
When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>
__device__ void transducer_joint_single_vec_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim,
float scale,
scalar_t *inGrad,
int yBlockOffset=0){
const int batch = blockIdx.z;
const int x = blockIdx.y - yBlockOffset;
const int hOffset = blockIdx.x*C10_WARP_SIZE*V;
const int wid = threadIdx.y;
const int lid = threadIdx.x;
const int numWarp = blockDim.y;
// Figure out the vectorization type for mask
using mvec_t = mvec_type<V>;
OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput,
bwdFasterDim);
const auto maxXLen = offsetCal.getMaxXLen();
const auto myXLen = offsetCal.getMyXLen();
const auto myYLen = offsetCal.getMyYLen();
scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
acc_t warpSum[V];
scalar_t inBuffer[V];
uint8_t maskBuffer[V];
scalar_t outBuffer[V];
auto myInGradVec = reinterpret_cast<vec_t*>(myInGrad);
auto outBufferVec = reinterpret_cast<vec_t*>(outBuffer);
if (x < myXLen){
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideX = offsetCal.getStrideX();
const auto strideY = offsetCal.getStrideY();
const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset
:nullptr;
for (int i = 0; i < V; ++i)
warpSum[i] = 0;
// Each warp reduces numYPerWarp "y" first
auto numYPerWarp = (myYLen+numWarp-1)/numWarp;
for (int warpY = 0; warpY < numYPerWarp; ++warpY){
auto y = wid*numYPerWarp + warpY;
auto myGradVec = reinterpret_cast<vec_t const *>(myGrad + y*strideY);
auto myMaskVec = masked ? reinterpret_cast<mvec_t const *>(myMask + y*strideY)
: nullptr;
auto inBufferVec = reinterpret_cast<vec_t*>(inBuffer);
auto maskBufferVec = reinterpret_cast<mvec_t*>(maskBuffer);
if (hOffset + lid*V < hiddenSize and y < myYLen){
*inBufferVec = myGradVec[lid]; // vectorized load
if (masked){
*maskBufferVec = myMaskVec[lid];
#pragma unroll
for (int i = 0; i < V; ++i)
warpSum[i] += static_cast<acc_t>(inBuffer[i]) * maskBuffer[i] * scale;
}
else{
#pragma unroll
for (int i = 0; i < V; ++i)
warpSum[i] += inBuffer[i];
}
}
}
// transpose partial sum in SMEM and reduce further using warpReduce
for (int i = 0; i < V; ++i){
smem[lid*numWarp + wid] = warpSum[i];
__syncthreads();
auto sum = smem[wid*C10_WARP_SIZE + lid];
if (hOffset+(wid*C10_WARP_SIZE/numWarp)*V < hiddenSize){
sum = warpReduce(sum, numWarp);
if (lid % numWarp == 0){
outBuffer[i] = sum;
}
}
__syncthreads();
}
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// example of 4 warps (a, b, c, d) with 8 threads per warp
// Each warp needs 8 / 4 = 2 threads to write the results.
if (lid % numWarp == 0 and hOffset+(wid*C10_WARP_SIZE/numWarp + lid/numWarp)*V < hiddenSize)
myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec;
}
else if (wid == 0 and hOffset + lid*V < hiddenSize){
// Need to ensure the grad is zero for don't care region
myInGradVec[lid] = 0;
}
}
/*
Vectorized version of transducer_joint_combined_backward
Call transducer_joint_single_vec_backward twice on two input tensors.
The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>
__global__ void transducer_joint_combined_vec_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
float scale,
scalar_t *fGrad,
scalar_t *gGrad) {
if (blockIdx.y < maxFLen){
transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
false,
scale,
fGrad);
}
else{
transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
true,
scale,
gGrad,
maxFLen);
}
}
std::vector<torch::Tensor> transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize){
auto tensorOpt = f.options();
auto dtype = f.scalar_type();
const auto batchSize = f.size(0);
const auto maxFLen = f.size(1);
const auto maxGLen = g.size(1);
const auto hiddenSize = f.size(2);
bool masked = dropout or relu;
int64_t *batchOffsetPtr = nullptr;
torch::Tensor sum, mask;
auto maskOpt = tensorOpt.dtype(torch::kUInt8);
if (!packOutput){
sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt);
batchOffsetPtr = nullptr;
if (masked)
mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt);
}
else{
sum = torch::empty({packedBatch, hiddenSize}, tensorOpt);
batchOffsetPtr = batchOffset.data_ptr<int64_t>();
if (masked)
mask = torch::empty({packedBatch, hiddenSize}, maskOpt);
}
uint8_t *maskPtr = masked ? mask.data_ptr<uint8_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt);
// Simple heuristics
const int numThread = std::min(128, (static_cast<int>(hiddenSize)+C10_WARP_SIZE-1)
/ C10_WARP_SIZE * C10_WARP_SIZE);
if (opt == 0){
// vanilla kernel
const int threads = numThread;
const dim3 blocks(maxGLen, maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
transducer_joint_forward<scalar_t, OffsetCalFwd>
<<<blocks, threads, 0, stream>>>(
f.data_ptr<scalar_t>(),
g.data_ptr<scalar_t>(),
fLen.data_ptr<int>(),
gLen.data_ptr<int>(),
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
sum.data_ptr<scalar_t>());
}));
}
if (opt == 1){
// tiled version. For simplicity, assume tileF == tileG, even though the kernel can
// support more general cases.
const int threads = numThread;
const int hiddenPerBlock = numThread;
const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock,
(maxFLen+tileSize-1)/tileSize,
batchSize);
TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4,
"Expected tileSize to be in [1, 2, 4], but got ", tileSize);
at::PhiloxCudaState rng_engine_inputs;
if (masked){
// set up PRG when the input is masked. rng_engine_inputs will be used as a space filler
// for non-masked calls.
// Therefore no need to initialize.
c10::optional<at::Generator> gen_;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_,
at::cuda::detail::getDefaultCUDAGenerator());
// counterOffset records how many cuRAND calls each thread makes. For a tiled kernel,
// each thread processes tileF * tileG output elements.
int64_t counterOffset = tileSize * tileSize;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(counterOffset);
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*,
int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float,
at::PhiloxCudaState, scalar_t*, uint8_t*);
if (masked){
switch (tileSize){
case 2:
kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd,
true>;
break;
case 4:
kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd,
true>;
break;
}
}
else{
switch (tileSize){
case 1:
kernel = &transducer_joint_tiled_forward<scalar_t, 1, 1, 4, OffsetCalFwd,
false>;
break;
case 2:
kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd,
false>;
break;
case 4:
kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd,
false>;
break;
}
}
kernel<<<blocks, threads, 0, stream>>>(
f.data_ptr<scalar_t>(),
g.data_ptr<scalar_t>(),
fLen.data_ptr<int>(),
gLen.data_ptr<int>(),
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
hiddenPerBlock,
packOutput,
relu,
dropout,
1.0f - dropoutProb,
rng_engine_inputs,
sum.data_ptr<scalar_t>(),
maskPtr);
}));
}
THCudaCheck(cudaGetLastError());
if (masked)
return {sum, mask};
else
return {sum};
}
std::vector<torch::Tensor> transducer_joint_cuda_backward(
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput,
float scale){
auto grad = in[0];
bool masked = (in.size() == 2);
uint8_t *maskPtr = masked ? in[1].data_ptr<uint8_t>() : nullptr;
auto tensorOpt = grad.options();
auto dtype = grad.scalar_type();
const int batchSize = fLen.size(0);
const int hiddenSize = grad.size(-1);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE;
torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt);
torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt);
int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr<int64_t>();
// The number "y" I would like each thread to work on
const int workPerThread = 32;
// Since the bwd for f and g have the same thread block size, we need to use the max of the two.
int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread-1) / workPerThread);
// Would like to have at least 2 warps
numWarp = std::max(2, numWarp);
// cap on the maximum number of warps allowed
numWarp = std::min(maxNumWarp, numWarp);
// Need smem for transposing the partial sum. The partial sum is in a matrix of the shape
// numWarp x warpSize
const int smemSize = numWarp * C10_WARP_SIZE;
const dim3 threads(C10_WARP_SIZE, numWarp, 1);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] {
auto gradPtr = grad.data_ptr<scalar_t>();
auto fLenPtr = fLen.data_ptr<int>();
auto gLenPtr = gLen.data_ptr<int>();
auto fGradPtr = fGrad.data_ptr<scalar_t>();
auto gGradPtr = gGrad.data_ptr<scalar_t>();
// resolve the acc_t type
using acc_t = at::acc_type<scalar_t, true>;
using vec_t = uint64_t;
constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
constexpr int vecAlignment = std::alignment_of<vec_t>::value;
// if all input and output tensors meet the alignment requirement
bool memAlign = (reinterpret_cast<uint64_t>(gradPtr) % vecAlignment == 0)
and (reinterpret_cast<uint64_t>(fGradPtr) % vecAlignment == 0)
and (reinterpret_cast<uint64_t>(gGradPtr) % vecAlignment == 0);
if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){
// If vectorization helps and the alignment requirement is met, use the vectorized
// kernel. For simplicity, hiddenSize needs to be a multiple of vectFactor.
const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor),
maxFLen+maxGLen,
batchSize);
if (masked){
transducer_joint_combined_vec_backward
<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, true>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
else{
transducer_joint_combined_vec_backward
<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, false>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
}
else{
const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE,
maxFLen + maxGLen, batchSize);
if (masked){
transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, true>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
else{
transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, false>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
}
}));
return {fGrad, gGrad};
}
#include <torch/extension.h>
#include <vector>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput);
torch::Tensor transducer_loss_cuda_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput);
std::vector<torch::Tensor> transducer_loss_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor fLen,
torch::Tensor yLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput
) {
CHECK_INPUT(x);
CHECK_INPUT(label);
CHECK_INPUT(fLen);
CHECK_INPUT(yLen);
if (packedInput)
CHECK_INPUT(batchOffset);
return transducer_loss_cuda_forward(
x,
label,
fLen,
yLen,
batchOffset,
maxFLen,
blankIdx,
opt,
packedInput);
}
torch::Tensor transducer_loss_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor fLen,
torch::Tensor yLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput){
CHECK_INPUT(x);
CHECK_INPUT(label);
CHECK_INPUT(lossGrad);
CHECK_INPUT(alpha);
CHECK_INPUT(beta);
CHECK_INPUT(fLen);
CHECK_INPUT(yLen);
if (packedInput)
CHECK_INPUT(batchOffset);
return transducer_loss_cuda_backward(
x,
lossGrad,
alpha,
beta,
fLen,
yLen,
label,
batchOffset,
maxFLen,
blankIdx,
opt,
fuseSoftmaxBackward,
packedInput);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)");
m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)");
}
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h>
template<typename scalar_t>
__device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) {
// standard log-sum-exp trick is used here to provide better numerical stability
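// i.e. log(exp(a) + exp(b)) is evaluated as max(a, b) + log1p(exp(-|a - b|)), so the
// argument of exp() is never positive and the sum cannot overflow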
return (a >= b) ? a + std::log1p(exp(b-a)) : b + std::log1p(exp(a-b));
}
// Vanilla transducer loss function (i.e. forward-backward algorithm)
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// Forward (alpha) and backward (beta) paths are launched together. Input is assumed to be converted
// into log scale by the preceding log_softmax layer.
// Diagonal wavefront advancing, as commonly used in dynamic programming, is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.
// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// Don't-care region (t > audLen) or (u > txtLen) is removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
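// Recurrences implemented below, in the log domain (summary of Eq. (16) and (18) in [1]):
//   alpha(t, u) = logSumExp( alpha(t-1, u) + null(t-1, u), alpha(t, u-1) + y(t, u-1) )
//   beta(t, u)  = logSumExp( beta(t+1, u)  + null(t, u),   beta(t, u+1)  + y(t, u)   )
// and the per-sequence loss is -beta(0, 0) = -log Pr(y*|x).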
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_forward(
const scalar_t* x,
const int* label,
const int* audLen,
const int* txtLen,
const int64_t* batchOffset,
int64_t dictSize, // 64-bit indexing for data tensor
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
acc_t* alpha,
acc_t* beta,
scalar_t* loss) {
const int batch = blockIdx.y;
const int tid = threadIdx.x;
const auto myFLen = audLen[batch];
// Note that the start-of-sentence symbol accounts for the +1 here
const auto myGLen = txtLen[batch] + 1;
const auto myLabel = label + batch * (maxGLen-1);
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
const scalar_t* myX = x + myBatchOffset * dictSize;
int u = tid;
if (blockIdx.x == 0){
// alpha path
acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;
if (u == 0)
myAlpha[0] = 0;
__syncthreads();
for (int64_t step = 1; step < myFLen+myGLen-1; ++step){
// Move along the diagonal wavefront to leverage available parallelism
for (u = tid; u < myGLen; u += blockDim.x){
int64_t t = step - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(16) in [1]
if (u == 0){
// alpha(t, u) = alpha(t-1, u) * null(t-1, u)
myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen]
+ myX[((t-1)*myStrideT) * dictSize + blankIdx];
}
else if (t == 0){
// alpha(t, u-1) = alpha(t, u-1) * y(t, u-1)
myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]];
}
else{
// alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1)
acc_t current = myAlpha[(t-1)*maxGLen + u]
+ myX[((t-1)*myStrideT + u) * dictSize + blankIdx];
acc_t next = myAlpha[t*maxGLen + u - 1]
+ myX[(t*myStrideT + u - 1) * dictSize + myLabel[u - 1]];
myAlpha[t*maxGLen + u] = logSumExp(next, current);
}
}
}
__syncthreads();
}
}
else if (blockIdx.x == 1){
// beta path
acc_t* myBeta = beta + batch*maxFLen*maxGLen;
if (u == 0){
myBeta[(myFLen-1)*maxGLen + myGLen - 1] = myX[((myFLen-1)*myStrideT
+ myGLen - 1) * dictSize + blankIdx];
}
__syncthreads();
for (int64_t step = myFLen+myGLen - 3; step >= 0; --step){
for (u = tid; u < myGLen; u += blockDim.x){
int64_t t = step - u;
if (t >= 0 and t < myFLen and u >=0 and u < myGLen){
// Eq(18) in [1]
if (u == myGLen - 1){
// beta(t, u) = beta(t+1, u) * null(t, u)
myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u]
+ myX[(t*myStrideT + u) * dictSize + blankIdx];
}
else if (t == myFLen - 1){
// beta(t, u) = beta(t, u+1) * y(t, u)
myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1]
+ myX[(t*myStrideT + u) * dictSize + myLabel[u]];
}
else{
// beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u)
acc_t current = myBeta[(t+1)*maxGLen + u]
+ myX[(t*myStrideT + u) * dictSize + blankIdx];
acc_t next = myBeta[t*maxGLen + u + 1]
+ myX[(t*myStrideT + u) * dictSize + myLabel[u]];
myBeta[t*maxGLen + u] = logSumExp(next, current);
}
}
}
__syncthreads();
}
if (tid == 0)
loss[batch] = -myBeta[0];
}
}
// Transducer loss function (i.e. forward-backward algorithm) with batch loading optimization.
// Compared to the vanilla version, there are two optimizations:
// 1. load x in batch through loop unrolling to reduce the latency.
// 2. Use registers and shared memory to hold alpha and beta values passed from one step to the next.
// For simplicity, this kernel currently only supports U <= maxThread, which should be the common
// case. For cases where U > maxThread, the vanilla kernel is used as a fallback option.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// Forward (alpha) and backward (beta) paths are launched together. Input is assumed to be converted
// into log scale by the preceding log_softmax layer.
// Diagonal wavefront advancing, as commonly used in dynamic programming, is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.
// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// Don't-care region (t > audLen) or (u > txtLen) is removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t, int batchLdSize>
__global__ void transducer_loss_batch_load_forward(
const scalar_t* x,
const int* label,
const int* audLen,
const int* txtLen,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
acc_t* alpha,
acc_t* beta,
scalar_t* loss) {
const int batch = blockIdx.y;
int u = threadIdx.x;
const auto myFLen = audLen[batch];
const auto myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
const scalar_t* myX = x + myBatchOffset * dictSize;
scalar_t next[batchLdSize], current[batchLdSize];
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
if (blockIdx.x == 0){
// alpha path
acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;
// two SMEM regions for double buffering reads and writes to avoid data races
acc_t * const sharedAlpha[2] = {smem, smem+maxGLen};
sharedAlpha[0][u] = 0;
__syncthreads();
if (u == 0)
myAlpha[0] = 0;
auto myAlphaLabel = (u == 0) ? 0 : label[batch*(maxGLen-1) + u - 1];
// register used to pass value to the next step for the same thread
acc_t prvStepAlpha = 0;
for (int64_t step = 1; step < myFLen+myGLen-1+batchLdSize; step += batchLdSize){
// Move along the diagonal wavefront to leverage available parallelism
// Batch loading X through loop unrolling
#pragma unroll
for (int i = 0; i < batchLdSize; ++i){
if (step+i<myFLen+myGLen-1){
// index computing
int64_t t = step + i - u;
int64_t currentId = ((t-1)*myStrideT + u) * dictSize + blankIdx;
int64_t nextId = (t*myStrideT + u - 1) * dictSize + myAlphaLabel;
// main loading loop
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
if (u == 0){
current[i] = myX[currentId];
}
else if (t == 0){
next[i] = myX[nextId];
}
else{
current[i] = myX[currentId];
next[i] = myX[nextId];
}
}
}
}
// main computing loop
for (int i = 0; i < batchLdSize; ++i){
// swap the pointer for double buffering
auto sharedAlphaRd = sharedAlpha[(step+i-1)%2];
auto sharedAlphaWr = sharedAlpha[(step+i)%2];
if (step+i<myFLen+myGLen-1){
int64_t t = step + i - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(16) in [1]
if (u == 0)
prvStepAlpha = prvStepAlpha+current[i];
else if (t == 0)
prvStepAlpha = sharedAlphaRd[u-1]+next[i];
else
prvStepAlpha = logSumExp(prvStepAlpha+current[i], sharedAlphaRd[u-1]
+ next[i]);
sharedAlphaWr[u] = prvStepAlpha;
myAlpha[t*maxGLen + u] = prvStepAlpha;
}
}
__syncthreads();
}
}
}
else if (blockIdx.x == 1){
// beta path
acc_t* myBeta = beta + batch*maxFLen*maxGLen;
// two SMEM regions for double buffering reads and writes to avoid data races
acc_t * const sharedBeta[2] = {smem, smem + maxGLen};
sharedBeta[0][u] = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
__syncthreads();
auto myBetaLabel = (u == maxGLen - 1) ? 0 : label[batch*(maxGLen-1) + u];
// register used to pass value to the next step for the same thread
acc_t prvStepBeta = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
if (u == 0)
myBeta[(myFLen-1)*maxGLen + myGLen - 1] = prvStepBeta;
for (int64_t step = 1; step < myFLen+myGLen-1; step += batchLdSize){
// Move along the diagonal wavefront to leverage available parallelism
// Batch loading X
#pragma unroll
for (int i = 0; i < batchLdSize; ++i){
if (step+i<myFLen+myGLen-1){
// index computing
int64_t t = myFLen+myGLen - (step + i) - 2 - u;
int64_t currentId = (t*myStrideT + u) * dictSize + blankIdx;
int64_t nextId = (t*myStrideT + u) * dictSize + myBetaLabel;
// main loading loop
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
if (u == myGLen - 1){
current[i] = myX[currentId];
}
else if (t == myFLen - 1){
next[i] = myX[nextId];
}
else{
current[i] = myX[currentId];
next[i] = myX[nextId];
}
}
}
}
// main computing loop
for (int i = 0; i < batchLdSize; ++i){
// swap the pointer for double buffering
auto sharedBetaRd = sharedBeta[(step+i-1)%2];
auto sharedBetaWr = sharedBeta[(step+i)%2];
if (step+i<myFLen+myGLen-1){
int64_t t = myFLen+myGLen - (step + i) - 2 - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(18) in [1]
if (u == myGLen - 1)
prvStepBeta = prvStepBeta+current[i];
else if (t == myFLen - 1)
prvStepBeta = sharedBetaRd[u+1]+next[i];
else
prvStepBeta = logSumExp(prvStepBeta+current[i], sharedBetaRd[u+1]
+ next[i]);
sharedBetaWr[u] = prvStepBeta;
myBeta[t*maxGLen + u] = prvStepBeta;
}
}
__syncthreads();
}
}
if (u == 0)
loss[batch] = -prvStepBeta;
}
}
// Vanilla transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere,
// hence only Eq(20) in [1] is implemented in this kernel.
// Each thread block works on [batch, t, :, :] of data. Each thread works on a specific u at a time
// Since only gradients for the correct token and null token need to be updated, gradients at other
// locations are initialized to 0.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
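// Summary of the update below, in the log domain and with lossGrad folded in (illustrative):
//   dL/dx(t, u, label[u]) = -lossGrad * exp( alpha(t, u) + beta(t, u+1) + x(t, u, label[u]) - beta(0, 0) )
//   dL/dx(t, u, blank)    = -lossGrad * exp( alpha(t, u) + beta(t+1, u) + x(t, u, blank)    - beta(0, 0) )
// where the (t, u) = (myFLen-1, myGLen-1) corner uses 0 in place of the beta term for blank.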
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int t = blockIdx.x;
const int batch = blockIdx.y;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
auto myX = x + (myBatchOffset + t*myStrideT)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT)*dictSize;
auto myLabel = label + batch*(maxGLen-1);
int64_t u = tid;
while (t < myFLen and u < myGLen){
// Do the update
// loss = -ln(Pr(y*|x))
acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
if (u != myGLen - 1)
myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1]
+ myX[u*dictSize + myLabel[u]]);
if (t == myFLen - 1 and u == myGLen - 1)
myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myX[u*dictSize + blankIdx]);
else if (t != myFLen - 1)
myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u]
+ myX[u*dictSize + blankIdx]);
u += blockDim.x;
}
}
// Fused transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
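// With the softmax bwd folded in, the gradient written below amounts to (illustrative summary):
//   dL/dx(t, u, h) = lossGrad * exp( x(t, u, h) + alpha(t, u) - beta(0, 0) )
//                    * ( exp(beta(t, u)) - 1{h == label[u]} * exp(beta(t, u+1))
//                                        - 1{h == blank}    * exp(beta(t+1, u)) )
// i.e. the non-fused gradient above pushed through the log_softmax backward.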
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_fused_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int u = blockIdx.x;
const int t = blockIdx.y;
const int batch = blockIdx.z;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
if (t < myFLen and u < myGLen){
auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myLabel = label + batch*(maxGLen-1);
// load and store shared variables in SMEM
if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u];
myBetaTUp1 = myBeta[t*maxGLen + u + 1];
myBetaTp1U = myBeta[(t+1)*maxGLen + u];
myLabelShared = myLabel[u];
}
__syncthreads();
for (int64_t h = tid; h < dictSize; h += blockDim.x){
// Do the update
acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabelShared){
myGrad -= std::exp(grad + myBetaTUp1);
}
else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad);
else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBetaTp1U);
}
myXGrad[h] = myGrad;
}
}
else if (!packedInput){
// In non-pack mode, need to make sure the gradients for don't-care regions are zero.
for (int64_t h = tid; h < dictSize; h += blockDim.x){
myXGrad[h] = 0;
}
}
}
// Vectorized version of fused transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t, typename vec_t, int V>
__global__ void transducer_loss_fused_vec_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int u = blockIdx.x;
const int t = blockIdx.y;
const int batch = blockIdx.z;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myLabel = label + batch*(maxGLen-1);
// Variables for vectorization
scalar_t myXBuffer[V], myXGradBuffer[V];
auto myXVec = reinterpret_cast<vec_t const *>(myX);
auto myXGradVec = reinterpret_cast<vec_t*>(myXGrad);
auto myXBufferVec = reinterpret_cast<vec_t*>(myXBuffer);
auto myXGradBufferVec = reinterpret_cast<vec_t*>(myXGradBuffer);
if (t < myFLen and u < myGLen){
// load and store shared variables in SMEM
if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u];
if (t != myFLen - 1)
myBetaTp1U = myBeta[(t+1)*maxGLen + u];
if (u != myGLen - 1){
myBetaTUp1 = myBeta[t*maxGLen + u + 1];
myLabelShared = myLabel[u];
}
}
__syncthreads();
#pragma unroll
for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){
// Load myX in a vector form
*myXBufferVec = myXVec[h0/V];
// Do the update for a vector of input
#pragma unroll
for (int i = 0; i < V; ++i){
auto h = h0 + i;
acc_t grad = commonFactor + myXBuffer[i]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabelShared){
myGrad -= std::exp(grad + myBetaTUp1);
}
else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad);
else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBetaTp1U);
}
myXGradBuffer[i] = myGrad;
}
// Store myXGrad in a vector form
myXGradVec[h0/V] = *myXGradBufferVec;
}
}
else if (!packedInput){
// In non-packed mode, we need to make sure the gradients for don't-care regions are zero.
for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){
myXGradVec[h0/V] = 0;
}
}
}
std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput){
auto scalarType = x.scalar_type();
auto tensorOpt = x.options();
const int batchSize = label.size(0);
const int maxGLen = label.size(1) + 1;
const int dictSize = x.size(-1);
TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize,
"Expected blank index to be in the range of 0 to ",
dictSize-1,
", but got ",
blankIdx);
TORCH_CHECK(opt == -1 or opt == 0 or opt == 1,
"Got an invalid optimization level ",
opt);
// The data type of alpha and beta will be resolved at dispatch time,
// hence defined here and assigned later
torch::Tensor alpha;
torch::Tensor beta;
torch::Tensor loss = torch::empty({batchSize}, tensorOpt);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock;
const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(scalarType, "transducer_loss_cuda_forward", ([&] {
// resolve accumulation type
using acc_t = at::acc_type<scalar_t, true>;
auto accType = c10::CppTypeToScalarType<acc_t>::value;
auto accTensorOpt = tensorOpt.dtype(accType);
alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
// decide what kernel to launch based on the problem size
// if the required SMEM size or number of threads exceeds the limit, fall back to the vanilla
// kernel.
const auto smemSize = 2*maxGLen*sizeof(acc_t);
const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0
: (opt == -1) ? 1 : opt;
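// Note (added for clarity): with a float accumulator the batch-load kernel needs
// 2*maxGLen*sizeof(float) = 8*maxGLen bytes of SMEM, so on typical GPUs
// (1024 threads/block, >= 48 KB SMEM per block) the maxGLen > maxThreadPerBlock check trips first.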
const int threads = std::min(maxThreadPerBlock, maxGLen);
const dim3 blocks(2, batchSize, 1);
if (optFallBack == 0)
transducer_loss_forward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
label.data_ptr<int>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
loss.data_ptr<scalar_t>());
else if (optFallBack == 1)
transducer_loss_batch_load_forward<scalar_t, acc_t, 4>
<<<blocks, threads, smemSize, stream>>>(
x.data_ptr<scalar_t>(),
label.data_ptr<int>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
loss.data_ptr<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
return {alpha, beta, loss};
}
torch::Tensor transducer_loss_cuda_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput){
auto dtype = x.scalar_type();
torch::Tensor xGrad;
const int batchSize = label.size(0);
const int maxGLen = label.size(1) + 1;
const int dictSize = x.size(-1);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
const int warpSize = deviceProperties->warpSize;
const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (fuseSoftmaxBackward){
// alloc empty tensors for performance, hence need to ensure zeros are written to
// the don't-care regions in the kernel.
xGrad = torch::empty_like(x);
// Would like each thread to work on 4 hidden units
const int workPerThread = 4;
// Don't want to have more than 128 threads per thread block
const int maxThreadPerElmt = std::min(128, maxThreadPerBlock);
const int threads = std::min(maxThreadPerElmt, std::max(warpSize,
(dictSize+workPerThread-1)/workPerThread));
const dim3 blocks(maxGLen, maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
using vec_t = uint64_t;
using acc_t = at::acc_type<scalar_t, true>;
constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
constexpr int vecAlignment = std::alignment_of<vec_t>::value;
// if all input and output tensors meet the alignment requirement
bool memAlign = reinterpret_cast<uint64_t>(x.data_ptr<scalar_t>()) % vecAlignment == 0
and reinterpret_cast<uint64_t>(xGrad.data_ptr<scalar_t>())
% vecAlignment == 0;
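// e.g. for fp16 inputs, vec_t = uint64_t gives vectFactor = 4, so the vectorized kernel
// below is used only when dictSize is a multiple of 4 and both pointers are 8-byte aligned.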
if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){
transducer_loss_fused_vec_backward<scalar_t, acc_t, vec_t, vectFactor>
<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}
else{
transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}
}));
}
else{
// for the non-fused kernel, the gradients that need to be written are very sparse, hence
// initialize the tensor with all zeros.
xGrad = torch::zeros_like(x);
// don't launch more threads than needed.
const int threads = std::min(maxThreadPerBlock, maxGLen);
const dim3 blocks(maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
using acc_t = at::acc_type<scalar_t, true>;
transducer_loss_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}));
}
THCudaCheck(cudaGetLastError());
return xGrad;
}
......@@ -78,7 +78,6 @@
#include <THC/THC.h>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include "type_shim.h"
#include "compat.h"
......
from .fmha import FMHAFun
###############################################################################
# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
###############################################################################
import torch
import torch.nn.functional as F
import fmhalib as mha
class FMHAFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
batch_size = cu_seqlens.numel() - 1
if batch_size < 4:
context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
else:
context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
ctx.save_for_backward(qkv, S_dmask)
ctx.cu_seqlens = cu_seqlens
ctx.p_dropout = p_dropout
ctx.max_s = max_s
return context
@staticmethod
def backward(ctx, dout):
qkv, S_dmask = ctx.saved_tensors
batch_size = ctx.cu_seqlens.numel() - 1
if batch_size < 4:
dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
else:
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
return dqkv, None, None, None, None, None, None
class FMHA(torch.nn.Module):
def __init__(self, config):
super(FMHA, self).__init__()
self.p_dropout = config.attention_probs_dropout_prob
self.h = config.num_attention_heads
self.hidden_size = config.hidden_size
self.d = self.hidden_size // self.h
assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"
def forward(self, qkv, cu_seqlens, max_s, is_training=True):
ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
return ctx.view(-1, self.hidden_size)
import os
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
import torch.distributed.distributed_c10d as c10d
class DistributedFusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm.
Currently GPU-only. Requires Apex to be installed via
``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
This version of fused LAMB implements 2 fusions.
* Fusion of the LAMB update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
:class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
...
opt.step()
:class:`apex.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp,
you may choose any ``opt_level``::
opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
...
opt.step()
In general, ``opt_level="O1"`` is recommended.
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
......@@ -56,24 +59,38 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
(default: 1.0)
use_nvlamb (boolean, optional): whether to apply the adaptive learning rate to
parameters with 0.0 weight decay (default: False)
clip_grad_norm (boolean, optional): whether to handle gradient clipping
(default: True)
step_supports_amp_scaling(boolean, optional): whether to use customized
gradient unscaling logic (default: True)
.. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
class AtomicCounter(object):
def __init__(self):
self.value = 0
self.order = []
import threading
self._lock = threading.Lock()
def add(self, idx):
with self._lock:
self.value += 1
self.order.append(idx)
def __init__(self, params,
lr=1e-3, bias_correction = True, grad_averaging=True,
betas=(0.9, 0.999), eps=1e-8,
weight_decay=0., max_grad_norm=0.,
adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
amp_scale_adjustment=1.0, overlap_reductions=True,
adam_w_mode=True, use_nvlamb=False,
step_supports_amp_scaling=True, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
e5m2_allgather=False):
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
e5m2_allgather=False, verbose=False, clip_after_ar=True,
full_ar=False, set_param_views_to_flat_buffer=False, skip_allgather=False,
fuse_scale=False, param_order=None, nccl_allgather_channels=0):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
......@@ -81,46 +98,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
super(DistributedFusedLAMB, self).__init__(params, defaults)
self._init_args = {
'lr': lr,
'bias_correction': bias_correction,
'grad_averaging': grad_averaging,
'betas': betas,
'eps': eps,
'weight_decay': weight_decay,
'max_grad_norm': max_grad_norm,
'adam_w_mode': adam_w_mode,
'use_nvlamb': use_nvlamb,
'clip_grad_norm': clip_grad_norm,
'amp_scale_adjustment': amp_scale_adjustment,
'overlap_reductions': overlap_reductions,
'dwu_group_size': dwu_group_size,
'dwu_num_blocks': dwu_num_blocks,
'dwu_num_chunks': dwu_num_chunks,
'dwu_num_rs_pg': dwu_num_rs_pg,
'dwu_num_ar_pg': dwu_num_ar_pg,
'dwu_num_ag_pg': dwu_num_ag_pg,
'e5m2_allgather': e5m2_allgather}
self._init_done = False
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
def __first_step_init__(self,
lr=1e-3, bias_correction = True, grad_averaging=True,
betas=(0.9, 0.999), eps=1e-8,
weight_decay=0., max_grad_norm=0.,
adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
amp_scale_adjustment=1.0, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
e5m2_allgather=False):
global fused_adam_cuda, distributed_lamb_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
self._overflow_buf = torch.cuda.IntTensor([0])
self._has_overflow = False
self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
......@@ -128,9 +109,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
import amp_C
self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
self._grad_averaging = grad_averaging
self._adam_w_mode = 1 if adam_w_mode else 0
self._use_nvlamb = use_nvlamb
self._clip_grad_norm = clip_grad_norm
self._step_supports_amp_scaling = step_supports_amp_scaling
self._is_accumulation_step = False
self._last_step = False
self._overlap_reductions = overlap_reductions
......@@ -138,44 +120,176 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._e5m2_allgather = e5m2_allgather
self._verbose = verbose
self._clip_after_ar = clip_after_ar
self._full_ar = full_ar
self._fuse_scale = fuse_scale
self._L2_grad_norm = None
self._set_flat_param_view = set_param_views_to_flat_buffer
self._skip_ag = skip_allgather
self._fused_norm = fused_norm
self._current_process_group = c10d._get_default_group()
self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')
self._resume_from_checkpoint = False
self._step = torch.cuda.IntTensor([0])
# Master weight, moment, gradient buffers
self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
if self._full_ar: # full all reduce, only need AR and AG groups
self._ar_pg = []
# consider all the ranks
ranks = list(range(0, self._world_size))
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = l2_grad_norm_pg
for i in range(self._num_ar_pg):
if self._verbose:
print(f"creating new AR group {i}: {ranks}")
grp = torch.distributed.new_group(ranks=ranks)
if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
if self._verbose:
print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
if self._verbose:
print(f"created new AR group {i}: {ranks}")
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
if nccl_allgather_channels > 0:
os.putenv('NCCL_MAX_NCHANNELS', str(nccl_allgather_channels))
if self._num_ag_pg == 0:
self._ag_pg = self._ar_pg
self._ag_st = self._ar_st
self._num_ag_pg = self._num_ar_pg
else:
self._ag_pg = []
ranks = []
stride = torch.cuda.device_count()
for i in range(self._num_groups):
rs = list(range(i*stride, (i+1)*stride))
ranks.append(rs)
for rs in ranks:
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=rs)
if torch.distributed.get_rank() in rs:
if self._verbose:
print(f"creating AG group {i}: {rs}")
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
else: # reduce-scatter + all-reduce, need RS, AR, AG groups
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
if self._verbose:
print(f"creating new AR group {i}: {ranks}")
grp = torch.distributed.new_group(ranks=ranks)
if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
if self._verbose:
print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
if self._verbose:
print(f"created new AR group {i}: {ranks}")
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
if self._verbose:
print(f"creating RS group : {ranks}")
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = l2_grad_norm_pg
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
if self._verbose:
print(f"creating AG group : {ranks}")
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.barrier(group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream()
self._completion_st = torch.cuda.Stream()
self._step.record_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
self._one = torch.cuda.IntTensor([1])
self._first_step = True
self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
self._param_order = self.AtomicCounter()
p_offset = 0
p_i = 0
self._model_params = []
self._grads_info = []
self._grad_accs = []
self._group_properties = []
for group in self.param_groups:
prev = None
beta1, beta2 = group['betas']
beta3 = 1.0 - beta1 if self._grad_averaging else 1.0
bias_correction = 1 if group['bias_correction'] else 0
eps = group['eps']
weight_decay = group['weight_decay']
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
self._group_properties.append((
group['weight_decay'],
1 if group['bias_correction'] else 0,
weight_decay,
bias_correction,
beta1,
beta2,
1.0 - beta1 if grad_averaging else 1.0,
group['eps']
beta3,
eps
))
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
if self._set_flat_param_view:
if param_order:
# this is executed when param_order is specified by the user
self._param_order.add(param_order[p])
else:
self._param_order.add(p_i)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
......@@ -184,7 +298,9 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
if param_order:
self._param_order.order = torch.argsort(torch.tensor(self._param_order.order)).tolist()
self._grads_generated = [False]*len(self._model_params)
self._grads_fp16, self._grads_fp32 = [], []
if self._overlap_reductions:
self._current_block = self._num_blocks
......@@ -193,34 +309,58 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
def _lazy_init_stage1(self):
if self._lazy_init_stage1_done: return
p_i = 0
#self._model_params = []
#self._grad_accs = []
#self._group_properties = []
for group in self.param_groups:
for p in group['params']:
torch.distributed.broadcast(p, 0)
if not p.requires_grad:
continue
def wrapper(param, param_i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if not self._set_flat_param_view:
if self._first_step:
# first time
self._param_order.add(param_i)
else:
idx = self._param_order.order.index(param_i)
self._do_overlapped_reduction(idx, param)
else:
if not self._first_step:
idx = self._param_order.order.index(param_i)
self._do_overlapped_reduction(idx, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
wrapper(p, p_i)
p_i += 1
self._block_size = self._total_param_size // self._num_blocks
self._chunk_size = self._block_size // self._num_chunks
self._shard_size = self._chunk_size // self._group_size
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# initialize master weights, moments buffers if not loaded from checkpoint
if self._fp32_p is None:
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# FIXME: Rethink fp16 label since it's either uint8 or fp16
self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
self._individual_flat_grads = []
for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
def _flat_split(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
......@@ -228,11 +368,18 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
list_of_blocks = __blockify(self._flat_grads)
list_of_blocks = __blockify(p)
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
def _flat_split_no_shards(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
def _full_packed_split(p):
def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
......@@ -244,7 +391,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p):
def __packed_blockify(p):
packed_block_size = self._num_chunks*self._shard_size
......@@ -255,12 +401,86 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
def _split_assign(shards):
packed_block_size = self._num_chunks*self._shard_size
list_of_list_of_chunks=[]
for block_id in range(self._num_blocks):
list_of_chunks=[]
for chunk_id in range(self._num_chunks):
#self._fp16_g[block_id*packed_block_size+chunk_id*self._shard_size:block_id*packed_block_size+(chunk_id+1)*self._shard_size] = shards[block_id][chunk_id][self._rank_in_group]
list_of_chunks.append( shards[block_id][chunk_id][self._rank_in_group])
list_of_list_of_chunks.append(list_of_chunks)
return list_of_list_of_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
# this splitting scheme is needed when allgather needs to be split into multiple chunks in a contiguous way
self._new_params2_blocks, self._new_params2_chunks, self._new_params2_shards = _flat_split(self._new_params)
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
if self._full_ar:
# for gradient all-reduce
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
# for weight update
self._fp16_g_chunks = _split_assign(self._flat_grads_shards)
else:
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
self._lazy_init_stage1_done = True
def _lazy_init_stage2(self):
if self._lazy_init_stage2_done: return
if not self._set_flat_param_view:
# reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view
self._param_order.order.reverse()
# re-order model_params, grad_accs, group_properties lists
self._model_params = [self._model_params[i] for i in self._param_order.order]
self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
self._group_properties = [self._group_properties[i] for i in self._param_order.order]
def _get_flat_view(param):
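# Note: channels_last parameters are stored in (K, H, W, C) memory order even though
# their logical shape is (K, C, H, W); the as_strided views below re-express the
# parameter in storage order so that .view(-1) walks memory contiguously.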
if param.is_contiguous(memory_format=torch.channels_last):
K, C, H, W = param.shape
pv = param.as_strided(size=(K,H,W,C), stride=(H*W*C, W*C, C, 1))
elif param.is_contiguous(memory_format=torch.channels_last_3d):
K, C, D, H, W = param.shape
pv = param.as_strided(size=(K,D,H,W,C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
else:
pv = param
return pv.view(-1)
# re-collect grads info (size, offset) after ordering
prev = None
p_offset = 0
self._grads_info = []
self._individual_flat_grads = []
for i, p in enumerate(self._model_params):
p_grads_size = p.numel()
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
# for the first iteration
self._do_overlapped_reduction(i, p)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
#print("self._low_param_i", self._low_param_i)
# This paragraph does two things:
# 1) Copy model parameters into master buffer
......@@ -274,7 +494,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._contrib_model_param_for_norm_fp16 = []
self._contrib_model_param_for_norm_fp32 = []
self._contrib_model_param_for_norm_is_fp16 = []
self._model_param_is_contrib = [False]*self._model_params_num
self._model_param_is_contrib = []
self._contrib_group_properties = []
for shard_id in range(self._group_size):
for block_id in range(self._num_blocks):
......@@ -290,14 +510,15 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
grad_offset = clipped_start - flat_grad_start
grad_length = clipped_end - clipped_start
shard_offset = clipped_start - flat_shard_start
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
pf = _get_flat_view(p)
model_param_fragment = pf[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
if model_param_fragment.dtype == torch.float16:
self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
else:
self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._rank_in_group:
self._model_param_is_contrib[param_i] = True
self._model_param_is_contrib.append(param_i)
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
......@@ -306,7 +527,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
#print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
master_param_fragment.copy_(model_param_fragment)
if not self._resume_from_checkpoint:
master_param_fragment.copy_(model_param_fragment)
self._contrib_group_properties.append(group_props)
self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy
self._contrib_update_frag_for_norm.append(opti_state_u_fragment)
......@@ -322,7 +544,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
self._model_param_is_contrib = torch.tensor(self._model_param_is_contrib, dtype=torch.bool, device='cuda')
self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')
p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))
self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]
......@@ -340,62 +562,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None
self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = l2_grad_norm_pg
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream()
self._completion_st = torch.cuda.Stream()
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
self._lazy_init_stage2_done = True
def _init_everything(self):
if not self._init_done:
self.__first_step_init__(**self._init_args)
self._init_done = True
self.complete_reductions()
self._first_step = False
def set_is_accumulation_step(self, is_accumulation_step):
self._is_accumulation_step = is_accumulation_step
......@@ -419,9 +589,60 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
return flush_block
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size)
def _full_all_reduce_scale(self, block_id, scale):
works = [None]*self._num_chunks
if self._clip_after_ar:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
ar_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ar_stream):
works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
else:
glob_chunk_id = block_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
ar_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ar_stream):
works0 = torch.distributed.all_reduce(self._flat_grads_blocks[block_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
for i in range(self._num_chunks):
works[i]=works0
self._reductions_works[block_id] = works
def _full_all_reduce(self, block_id):
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
ar_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ar_stream):
works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
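# Illustration (not from the original comment): with e.g. 2 blocks, 2 chunks and a
# group size of 4, each rank receives its own shard of every (block, chunk) pair,
# laid out contiguously as [b0c0, b0c1, b1c0, b1c1] -- the same packed order used
# for the fp32 master state by _packed_split.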
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
rs_stream.wait_stream(self._l2_grad_norm_st)
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
def _reduce_scatter_and_all_reduce(self, block_id):
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
......@@ -430,6 +651,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
rs_stream.wait_stream(self._l2_grad_norm_st)
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
......@@ -443,17 +665,66 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
# Compute L2 grad norm
if block_id == 0:
def _pipeline_block_reductions(self, block_id):
if self._clip_after_ar:
self._flatten_grad_mt(1.0/self._world_size)
if self._full_ar:
self._full_all_reduce(block_id)
else:
self._reduce_scatter_and_all_reduce(block_id)
# Compute L2 grad norm
if block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = torch.empty([1], device='cuda')
if 0:#self._full_ar:
l2_grad_norm_sq = self._flat_grads_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)**2
else:
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt()
else:
# Copy model grads to flat grads buffer
self._flatten_grad_mt(1.0)
# Compute L2 grad norm
self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._l2_grad_norm_st):
if not self._fused_norm:
self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
# Apply clipping & pre-reduction scaling on grads
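# The lines below form a branchless clamp: coeff = max_grad_norm / grad_norm is capped
# at 1.0, and the index_select falls back to 1.0 when coeff is NaN (coeff + 1 > coeff is
# False only for non-finite values); the selected factor is divided by world_size so the
# pre-reduction scaling also performs gradient averaging.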
loss_scale = self.global_scale
max_grad_norm = loss_scale*self.defaults['max_grad_norm']
coeff = max_grad_norm /(1e-6+self.L2_grad_norm)
coeff = (coeff>1) * self._one + (coeff<=1) * coeff
tmp = torch.cat(((self._one), (coeff)))
index = (coeff+1>coeff).int()
scale = tmp.index_select(0, index).half()/self._world_size
if not self._fuse_scale:
self._flat_grads.mul_(scale)
if self._full_ar:
if self._fuse_scale:
self._full_all_reduce_scale(block_id, scale)
else:
self._full_all_reduce(block_id)
else:
if self._fuse_scale:
self._reduce_scatter_and_all_reduce_scale(block_id, scale)
else:
self._reduce_scatter_and_all_reduce(block_id)
if block_id == 0:
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = torch.empty([1], device='cuda')
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
def __compute_contrib_param_norm(self):
if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
......@@ -471,24 +742,34 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
def __compute_contrib_update_norm(self):
l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm)
torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
l2_norm = torch.sqrt(l2_norm)
return l2_norm.masked_select(self._model_param_is_contrib)
return l2_norm
def _pipeline_step(self):
# If self._clip_grad_norm is False, we assume gradient clipping already
# happened outside the optimizer and self._global_scale has already
# been set to the combined scale, i.e. it's no longer the current loss
# scale used by the loss scaler.
# This is for model-parallel cases in which we need to get the global gradient
# norm via all-reduce outside the optimizer to do the clipping.
combined_scale = self.global_scale
max_grad_norm = self.defaults['max_grad_norm']
global_scale = self.global_scale
# if clip before ar, set max_grad_norm to 0
max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar
self._completion_st.wait_stream(self._l2_grad_norm_st)
global_grad_norm = self.L2_grad_norm
if self._clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
combined_scale = max_grad_norm / (global_grad_norm / self.global_scale + 1e-6)
combined_scale = self.global_scale / min(1, combined_scale)
# check global_grad_norm and fill overflow_buf
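# (x + 1 > x) is False for NaN and +Inf, so is_finite is 1 only when the gradient norm
# is finite; overflow_buf below is simply its complement (toggled between 0 and 1).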
is_finite = (global_grad_norm + 1 > global_grad_norm).int()
self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
if not self._clip_after_ar:
torch.distributed.all_reduce(is_finite,
op=torch.distributed.ReduceOp.MIN,
group=self._current_process_group)
torch.distributed.all_reduce(self._overflow_buf,
op=torch.distributed.ReduceOp.MAX,
group=self._current_process_group)
# increment step counter if no overflow
self._step += is_finite
self._completion_st.wait_stream(torch.cuda.current_stream())
self._completion_st.wait_stream(self._l2_grad_norm_st)
# Call step kernel once per step
# Call all-gather once per step
......@@ -504,42 +785,67 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._contrib_beta2,
self._contrib_beta3,
self._contrib_bias_correction,
self.param_groups[0]['step'],
self._step,
self._contrib_epsilon,
self._adam_w_mode,
self._contrib_weight_decay,
combined_scale)
global_scale,
global_grad_norm,
max_grad_norm)
upd_norm = self.__compute_contrib_update_norm()
multi_tensor_applier(self.multi_tensor_lamb_update_weights,
self._overflow_buf,
self._contrib_update_weights_tensor_list, # u, p, p_copy
param_norm,
upd_norm,
self.param_groups[0]['lr'],
self._offsets,
self._lr,
self._contrib_weight_decay,
global_grad_norm,
self._use_nvlamb)
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
if not self._skip_ag:
# allgather chunking is currently not supported for clip after allreduce
if not self._clip_after_ar:
for block in range(self._num_blocks):
for chunk in range(self._num_chunks):
torch.distributed.all_gather(self._new_params2_shards[block][chunk], self._fp16_p_chunks[block][chunk], group=self._ag_pg[0], no_copy=True)
else:
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if len(self._grads_fp16) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads_fp16)),
scale)
if not self._fused_norm:
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads_fp16)),
scale)
else:
self._L2_grad_norm=multi_tensor_applier(
amp_C.multi_tensor_l2norm_scale,
self._overflow_buf,
list(zip(*self._grads_fp16)),
scale, False)[0].float()
self._grads_fp16 = []
if len(self._grads_fp32) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads_fp32)),
scale)
if not self._fused_norm:
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads_fp32)),
scale)
else:
self._L2_grad_norm=multi_tensor_applier(
amp_C.multi_tensor_l2norm_scale,
self._overflow_buf,
list(zip(*self._grads_fp32)),
scale, False)[0].float()
self._grads_fp32 = []
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
self._init_everything()
def _do_overlapped_reduction(self, param_i, param):
if not self._is_accumulation_step:
# handle overlapped reductions
if param.dtype == torch.float16:
......@@ -547,12 +853,13 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
else:
self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
self._grads_generated[param_i]=True
if self._overlap_reductions and not self._last_step:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._pipeline_block_reductions(block_id)
if not self._first_step and not self._last_step:
if self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._pipeline_block_reductions(block_id)
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
......@@ -565,14 +872,12 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
@property
def L2_grad_norm(self):
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
self._init_everything()
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, grad_generated in enumerate(self._grads_generated):
......@@ -583,7 +888,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._flat_grads[param_offset:param_offset+param_size].zero_()
self._grads_generated[param_i] = True
if self._last_step or not self._overlap_reductions:
if self._first_step or self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
for block_id in range(self._num_blocks-1,-1,-1):
self._pipeline_block_reductions(block_id)
......@@ -593,36 +898,35 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def step(self, closure=None):
def step(self, closure=None, grad_scaler=None):
loss = None
if closure is not None:
loss = closure()
# assume the same step across the group for now to simplify things
# a per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
for param_group in self.param_groups:
if 'step' in param_group:
param_group['step'] += 1
else:
param_group['step'] = 1
self._pipeline_step()
with torch.cuda.stream(self._completion_st):
# Copy self._new_params to model params
self._overflow_buf.zero_()
with torch.no_grad():
if self._packed_flat_to_model_params_fp16 is not None:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params_fp16)
if self._packed_flat_to_model_params_fp32 is not None:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params_fp32)
if grad_scaler is not None:
found_inf = self._overflow_buf.float()
optimizer_state = grad_scaler._per_optimizer_states[id(self)]
current_device = torch.device('cuda', torch.cuda.current_device())
optimizer_state["found_inf_per_device"][current_device] = found_inf
self._completion_st.wait_stream(torch.cuda.current_stream())
if not self._set_flat_param_view:
with torch.cuda.stream(self._completion_st):
# Copy self._new_params to model params
with torch.no_grad():
if self._packed_flat_to_model_params_fp16 is not None:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params_fp16)
if self._packed_flat_to_model_params_fp32 is not None:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params_fp32)
torch.cuda.current_stream().wait_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
......@@ -630,4 +934,42 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
return loss
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`DistributedFusedLAMB` instance.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
# save step, master weights and first/second moments
state_dict = {}
state_dict['step'] = self._step
state_dict['fp32_p'] = self._fp32_p
state_dict['fp32_m'] = self._fp32_m
state_dict['fp32_v'] = self._fp32_v
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If a DistributedFusedLAMB instance was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``optimizer.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# restore step, master weights and first/second moments
self._step = state_dict['step']
self._fp32_p = state_dict['fp32_p'].to(device="cuda")
self._fp32_m = state_dict['fp32_m'].to(device="cuda")
self._fp32_v = state_dict['fp32_v'].to(device="cuda")
self._resume_from_checkpoint = True
......@@ -111,7 +111,7 @@ class FusedLAMB(torch.optim.Optimizer):
continue
if p.dtype == torch.float32:
g_all_32.append(p.grad.data)
elif p.dytpe == torch.float16:
elif p.dtype == torch.float16:
g_all_16.append(p.grad.data)
else:
raise RuntimeError('FusedLAMB only support fp16 and fp32.')
......
......@@ -9,7 +9,7 @@ from apex.contrib.sparsity import ASP
## Initializing ASP
Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/infercence:
Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:
```
ASP.prune_trained_model(model, optimizer)
```
......
###############################################################################
# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
###############################################################################
import sys
import torch
import numpy as np
import unittest
import math
import fmhalib as mha
def py_mha(qkv, amask, b, s, h, d):
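    # Pure-PyTorch reference for the fused kernel: split the packed qkv into per-head
    # q/k/v of shape (b, h, s, d), apply scaled dot-product attention with the additive
    # mask, and return the context in (b, s, h, d) layout.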
qkv = qkv.view(b, s, h, 3, d)
q = qkv[:, :, :, 0, :].permute(0,2,1,3)
k = qkv[:, :, :, 1, :].permute(0,2,1,3)
v = qkv[:, :, :, 2, :].permute(0,2,1,3)
p = torch.matmul(q.float(), k.permute(0,1,3,2).float())
p_masked = p / math.sqrt(d) + (1.0 - amask) * -10000.0
s = torch.softmax(p_masked, -1).to(qkv.dtype)
ctx = torch.matmul(s, v)
ctx = ctx.permute(0,2,1,3).contiguous()
ctx.retain_grad()
return ctx
class TestFMHA(unittest.TestCase):
def run_test(self, s, b):
print(f'Test s={s} b={b}')
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
dtype = torch.float16
device = torch.device('cuda')
h = 16
d = 64
slens = [s] * b
a = torch.tensor(np.array([0] + slens), dtype=torch.int32)
amask = torch.ones(b,h,s,s, dtype=dtype, device=device)
seqlens = torch.tensor(slens, dtype=torch.int32, device=device)
cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)
total = cu_seqlens[-1].item()
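        # qkv_vs below is the batch-packed (total, 3, h, d) layout that is passed to the
        # fused kernel together with the cumulative sequence lengths in cu_seqlens.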
qkv = torch.randn((b,s,h,3,d), device=device, dtype=dtype)
qkv_vs = qkv.permute(0,1,3,2,4).contiguous().view(b*s, 3, h,d)
qkv.requires_grad = True
if b < 4:
ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, None)
else:
ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, None)
ctx = ctx.view(b,s,h,d)
ctx_ref = py_mha(qkv, amask, b,s,h,d)
self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3))
labels = torch.randn_like(ctx_ref)
diff = ctx_ref - labels
l = (diff * diff).sum() / b
l.backward()
dw = ctx_ref.grad.permute(0,2,1,3)
dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()
if b < 4:
dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
else:
dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)
self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
def test_128(self):
self.run_test(128, 32)
def test_256(self):
self.run_test(256, 32)
def test_384(self):
self.run_test(384, 32)
def test_512(self):
self.run_test(512, 32)
self.run_test(512, 2)
self.run_test(512, 3)
if __name__ == '__main__':
unittest.main()
import torch
import unittest
import torch.nn.functional as F
from apex import fused_dense
from torch import nn
from apex import amp
class FusedDenseTest(unittest.TestCase):
def setUp(self, seed=0):
torch.manual_seed(seed)
#torch.cuda.manual_seed_all(seed)
self.seq_length = 512
self.sequences = 3
self.hidden_dim = 1024
self.ref_inputs = torch.randn(self.sequences*self.seq_length, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).int().half().requires_grad_(True)
self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)
self.dense = fused_dense.FusedDense(1024, 3072)
self.dense.half()
self.dense.cuda()
def test_fused_dense(self) :
y_tst = self.dense(self.tst_inputs)
y_ref = torch.matmul(self.ref_inputs,self.dense.weight.t())+self.dense.bias
dy = torch.randn_like(y_tst).half()
y_tst.backward(dy)
dw_ref = torch.matmul(dy.t(), self.ref_inputs)
dx_ref = torch.matmul(dy, self.dense.weight.clone())
db_ref = dy.sum(0, False)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(dw_ref, self.dense.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(dx_ref, self.tst_inputs.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(db_ref, self.dense.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.transducer import TransducerJoint
import transducer_ref
class TransducerJointTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, for_vector_kernel):
self.B = 4
T_min = 51
T_max = 101
U_min = 12
U_max = 25
if for_vector_kernel:
H = 512
else:
H = 509
dtype = torch.float16
device = "cuda"
self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)
self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)
self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
self.dropout_prob = 0.5
# Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
# the loss function
for b in range(self.B):
self.h_grad[b, self.f_len[b]:, :, :] = 0
self.h_grad[b, :, self.g_len[b]:, :] = 0
self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)
def _pack(self, x, f_len, g_len):
B = x.size(0)
list_x = []
for b in range(B):
list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(f_len * g_len, dim=0)
return x_packed
def _unpack(self, x, f_len, g_len):
batch_offset = torch.cumsum(f_len * g_len, dim=0)
x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)
B = self.h_grad.size(0)
H = self.h_grad.size(-1)
for b in range(B):
my_batch_offset = 0 if b == 0 else batch_offset[b-1]
my_f_len = f_len[b]
my_g_len = g_len[b]
for t in range(my_f_len):
x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len :
my_batch_offset + t*my_g_len + my_g_len]
return x_unpacked
def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):
self.gen_input(for_vector_kernel=for_vector_kernel)
# Generate reference
f_ref = self.f_tst.data.clone()
g_ref = self.g_tst.data.clone()
f_ref.requires_grad = True
g_ref.requires_grad = True
my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout,
dropout_prob=self.dropout_prob, probe_mask=True)
if not pack_output:
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len)
h_tst.backward(self.h_grad)
if dropout:
mask = my_joint.mask_probe[0]
else:
batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len,
batch_offset=batch_offset,
packed_batch=batch_offset[-1])
h_tst.backward(self.h_grad_packed)
if dropout:
mask_packed = my_joint.mask_probe[0]
mask = self._unpack(mask_packed, self.f_len, self.g_len)
# reference
h_ref, f_grad_ref, g_grad_ref \
= transducer_ref.transducer_joint_reference(f=f_ref,
g=g_ref,
h_grad=self.h_grad,
f_len=self.f_len,
g_len=self.g_len,
pack_output=pack_output,
relu=relu,
dropout=dropout,
dropout_prob=self.dropout_prob,
mask=mask if dropout else None)
f_grad_tst = self.f_tst.grad
g_grad_tst = self.g_tst.grad
self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4))
def test_transducer_joint(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_vec(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)
def test_transducer_joint_pack(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_vec_pack(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_vec_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)
def test_transducer_joint_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_vec_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
def test_transducer_joint_vec_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)
def test_transducer_joint_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)
def test_transducer_joint_vec_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
import torch
import unittest
from apex.contrib.transducer import TransducerLoss
import transducer_ref
class TransducerLossTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, scalar_t, for_vector_kernel):
self.B = 5
T_min = 23
T_max = 51
U_min = 12
U_max = 25
V = 16 if for_vector_kernel else 14
self.blank_idx = V - 1
device = "cuda"
self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True,
device=device)
self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1
self.x_tst_packed, self.batch_offset = self._pack(self.x_tst)
# Generate reference
x_ref = self.x_tst.data.clone()
x_ref.requires_grad = True
loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0)
_, _, self.grad_ref, self.loss_ref \
= transducer_ref.transducer_loss_reference( x=x_ref,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx,
loss_grad=loss_grad)
def _pack(self, x):
list_x = []
for b in range(self.B):
list_x_row = [x[b, t, : self.y_len[b]+1] for t in range(self.f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(self.f_len * (self.y_len+1), dim=0)
return x_packed, batch_offset
def _unpack(self, x):
x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1),
dtype=x.dtype, device=x.device)
for b in range(self.B):
my_batch_offset = 0 if b == 0 else self.batch_offset[b-1]
my_f_len = self.f_len[b]
my_g_len = self.y_len[b] + 1
for t in range(my_f_len):
for u in range(my_g_len):
x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u]
return x_unpacked
def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel):
self.gen_input(scalar_t, for_vector_kernel)
my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward,
packed_input=packed_input)
if not packed_input:
loss_tst = my_loss( x=self.x_tst,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx)
loss_tst.mean().backward()
grad_tst = self.x_tst.grad
else:
loss_tst = my_loss( x=self.x_tst_packed,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx,
batch_offset=self.batch_offset,
max_f_len=max(self.f_len))
loss_tst.mean().backward()
grad_tst_packed = self.x_tst_packed.grad
grad_tst = self._unpack(grad_tst_packed)
return loss_tst, grad_tst
def test_transducer_loss_fp32(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32,
fuse_softmax_backward=False,
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5))
def test_transducer_loss_fp16(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=False,
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed_vec(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True,
for_vector_kernel=True)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
import torch
import numpy as np
import pdb
def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):
def log_sum_exp(a, b):
if (a >= b):
return a + torch.log(1 + torch.exp(b-a))
else:
return b + torch.log(1 + torch.exp(a-b))
def forward_alpha(x, label, f_len, y_len, blank_idx):
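        # alpha[b, t, u]: log-probability of having emitted the first u labels after
        # consuming t acoustic frames. Each interior cell combines a blank transition
        # from (t-1, u) with a label transition from (t, u-1) via log_sum_exp.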
B, T, U, V = x.size()
acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
alpha = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
for b in range(B):
alpha[b, 0, 0] = 0
for t in range(1, f_len[b]):
alpha[b, t, 0] = alpha[b, t-1, 0] + x[b, t-1, 0, blank_idx]
for u in range(1, y_len[b]+1):
alpha[b, 0, u] = alpha[b, 0, u-1] + x[b, 0, u-1, label[b, u-1]]
for t in range(1, f_len[b]):
for u in range(1, y_len[b]+1):
curr_ = alpha[b, t-1, u] + x[b, t-1, u, blank_idx]
next_ = alpha[b, t, u-1] + x[b, t, u-1, label[b, u-1]]
alpha[b, t, u] = log_sum_exp(curr_, next_)
return alpha
def forward_beta(x, label, f_len, y_len, blank_idx):
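        # beta[b, t, u]: log-probability of emitting the remaining labels and blanks
        # from state (t, u) to the end of the lattice; filled backwards over t and u,
        # again combining a blank step and a label step via log_sum_exp.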
B, T, U, V = x.shape
acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
beta = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
for b in range(B):
beta[b, f_len[b]-1, y_len[b]] = x[b, f_len[b]-1, y_len[b], blank_idx]
for t in range(f_len[b]-2, -1, -1):
beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx]
for u in range(y_len[b]-1, -1, -1):
beta[b, f_len[b]-1, u] = beta[b, f_len[b]-1, u+1] + x[b, f_len[b]-1, u, label[b, u]]
for t in range(f_len[b]-2, -1, -1):
for u in range(y_len[b]-1, -1, -1):
curr_ = beta[b, t+1, u] + x[b, t, u, blank_idx]
next_ = beta[b, t, u+1] + x[b, t, u, label[b, u]]
beta[b, t, u] = log_sum_exp(curr_, next_)
return beta
def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx):
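        # Gradient w.r.t. the log-softmax output: for every lattice transition, the
        # negated posterior occupancy exp(alpha + log_prob(transition) + beta(next)
        # - beta[0, 0]), scaled by the incoming loss gradient via common_factor.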
grad = torch.zeros_like(x)
B, T, U, V = x.size()
for b in range(B):
common_factor = torch.log(loss_grad[b]) + alpha - beta[b, 0, 0]
# next
for u in range(y_len[b]):
grad[b, :f_len[b], u, label[b, u]] = -torch.exp(common_factor[b, :f_len[b], u]
+ beta[b, :f_len[b], u+1]
+ x[b, :f_len[b], u, label[b, u]])
# current
grad[b, :f_len[b]-1, :y_len[b]+1, blank_idx] \
= -torch.exp(common_factor[b, :f_len[b]-1, :y_len[b]+1]
+ beta[b, 1:f_len[b], :y_len[b]+1]
+ x[b, :f_len[b]-1, :y_len[b]+1, blank_idx])
grad[b, f_len[b]-1, y_len[b], blank_idx] = -torch.exp(common_factor[b, f_len[b]-1, y_len[b]]
+ x[b, f_len[b]-1, y_len[b], blank_idx])
return grad
x_log = torch.nn.functional.log_softmax(x, dim=-1)
alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx)
beta = forward_beta(x_log, label, f_len, y_len, blank_idx)
grad = backward(x_log, label, f_len, y_len, alpha, beta,
loss_grad, blank_idx)
x_log.backward(grad)
loss = -beta[:, 0, 0]
loss = loss.to(x.dtype)
return alpha, beta, x.grad, loss
def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout,
dropout_prob=0, mask=None):
    if dropout and mask is None:
        raise NotImplementedError("mask needs to be supplied to test dropout.")
B, T, H = f.size()
U = g.size(1)
f_expand = f.unsqueeze(dim=2)
g_expand = g.unsqueeze(dim=1)
h = f_expand + g_expand
if relu:
h = torch.nn.functional.relu(h)
if dropout:
h *= mask
scale = 1/(1-dropout_prob)
h *= scale
h.backward(h_grad)
if pack_output == False:
        # intentionally set the don't-care regions to -1 to test whether the transducer
        # joint writes these regions to avoid NaN and inf
for b in range(B):
h[b, f_len[b]:] = -1
h[b, :, g_len[b]:] = -1
return h, f.grad, g.grad
# packing
list_to_pack = []
for b in range(B):
list_to_pack.append(h[b, :f_len[b], :g_len[b], :].reshape(-1, H))
h_packed = torch.cat(list_to_pack)
return h_packed, f.grad, g.grad
from .transducer import TransducerJoint
from .transducer import TransducerLoss
\ No newline at end of file
import torch
import transducer_loss_cuda
import transducer_joint_cuda
class TransducerJoint(torch.nn.Module):
"""Transducer joint
    Details of this operation can be found in: Sequence Transduction with Recurrent
    Neural Networks
Arguments:
pack_output (bool, optional): whether to pack the output in a compact form with don't-care
data being removed. (default: False)
relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1
(default: False)
dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1
(default: False)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm.
(default: 1)
fwd_tile_size (int, optional): tile size used in forward operation. This argument will be
ignored if opt != 1. (default: 4)
dropout_prob (float, optional): dropout probability. (default: 0.0)
        probe_mask (bool, optional): a flag used to probe the mask generated by the ReLU
            and/or dropout operation. When this argument is set to True, the mask can be
            accessed through self.mask_probe. (default: False)
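
    Example (an illustrative sketch of the unpacked path; shapes and device are placeholders)::

        joint = TransducerJoint(pack_output=False)
        f = torch.randn(2, 10, 64, dtype=torch.float16, device="cuda")    # (B, T, H)
        g = torch.randn(2, 5, 64, dtype=torch.float16, device="cuda")     # (B, U, H)
        f_len = torch.tensor([10, 8], dtype=torch.int, device="cuda")
        g_len = torch.tensor([5, 4], dtype=torch.int, device="cuda")
        h = joint(f, g, f_len, g_len)    # joint output of shape (B, T, U, H)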
"""
def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4,
dropout_prob=0, probe_mask=False):
super(TransducerJoint, self).__init__()
self.pack_output = pack_output
self.relu = relu
self.dropout = dropout
self.dropout_prob = dropout_prob
self.opt = opt
self.fwd_tile_size = fwd_tile_size
self.dummy_batch_offset = torch.empty(0)
masked = self.relu or self.dropout
self.mask_probe = [] if masked and probe_mask else None
if masked and opt != 1:
raise NotImplementedError("ReLU and dropout fusion is only supported with opt=1")
def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):
"""Forward operation of transducer joint
Arguments:
            f (tensor): transcription vector from the encode block, of shape (B, T, H).
            g (tensor): prediction vector from the predict block, of shape (B, U, H).
f_len (tensor): length of transcription vector for each batch.
g_len (tensor): length of prediction vector minus 1 for each batch.
batch_offset (tensor, optional): tensor containing the offset of each batch
in the results. For example, batch offset can be obtained from:
batch_offset = torch.cumsum(f_len*g_len, dim=0)
This argument is required if pack_output == True, and is ignored if
pack_output == False. (default: None)
packed_batch (int, optional): the batch size after packing. This argument is
ignored if pack_output == False. (default: 0)
"""
my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset
if self.pack_output and (batch_offset is None or packed_batch == 0):
raise Exception("Please specify batch_offset and packed_batch when packing is enabled")
dropout = self.dropout and self.training # only dropout for training
return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout,
my_batch_offset, packed_batch, self.opt,
self.fwd_tile_size, self.dropout_prob, self.mask_probe)
class TransducerLoss(torch.nn.Module):
"""Transducer loss
    Details of this loss function can be found in: Sequence Transduction with Recurrent
    Neural Networks
Arguments:
        fuse_softmax_backward (bool, optional): whether to fuse the backward of the
            transducer loss with softmax. (default: True)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized
algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1)
        packed_input (bool, optional): whether the input is packed in a compact form
            with don't-care data being removed. (default: False)
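
    Example (an illustrative sketch of the unpacked path; shapes and device are placeholders)::

        loss_fn = TransducerLoss(packed_input=False)
        # x holds joint-network outputs of shape (B, T, U, V), with blank_idx = V - 1.
        x = torch.randn(2, 10, 6, 29, dtype=torch.float16, device="cuda", requires_grad=True)
        label = torch.randint(0, 28, (2, 5), dtype=torch.int, device="cuda")
        f_len = torch.tensor([10, 8], dtype=torch.int, device="cuda")
        y_len = torch.tensor([5, 4], dtype=torch.int, device="cuda")
        loss = loss_fn(x, label, f_len, y_len, blank_idx=28)
        loss.mean().backward()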
"""
def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False):
super(TransducerLoss, self).__init__()
self.fuse_softmax_backward = fuse_softmax_backward
self.opt = opt
self.packed_input = packed_input
self.dummy_batch_offset = torch.empty(0)
def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None,
debug_list=None):
"""Forward operation of transducer joint
Arguments:
x (tensor): input tensor to the loss function with a shape of (B, T, U, H).
label (tensor): labels for the input data.
f_len (tensor): lengths of the inputs in the time dimension for each batch.
y_len (tensor): lengths of the labels for each batch.
blank_idx (int): index for the null symbol.
batch_offset (tensor, optional): tensor containing the offset of each batch
in the input. For example, batch offset can be obtained from:
batch_offset = torch.cumsum(f_len*(y_len+1), dim=0)
This argument is required if packed_input == True, and is ignored if
packed_input == False. (default: None)
max_f_len (int, optional): maximum length of the input in the time dimension.
For example, it can be obtained as
max_f_len = max(f_len)
This argument is required if packed_input == True, and is ignored if
packed_input == False. (default: None)
            debug_list (list, optional): when an empty list is supplied, the alpha and
                beta tensors generated in the forward operation will be attached to this
                list for debugging purposes. (default: None)
"""
if self.packed_input:
if batch_offset is None or max_f_len is None:
                raise Exception("Please specify batch_offset and max_f_len when packing "
                                "is enabled")
my_batch_offset = batch_offset
my_max_f_len = max_f_len
else:
my_batch_offset = self.dummy_batch_offset
my_max_f_len = x.size(1)
return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len,
blank_idx, self.fuse_softmax_backward, debug_list,
self.opt, self.packed_input)
class TransducerLossFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx,
fuse_softmax_backward, debug_list, opt, packed_input):
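        # When the softmax backward is not fused into the loss backward, keep the
        # autograd graph of log_softmax alive so that TransducerLossFunc.backward can
        # propagate the gradient through it; otherwise a plain log_softmax suffices.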
if fuse_softmax_backward == False:
with torch.enable_grad():
x = torch.nn.functional.log_softmax(x, dim=-1)
else:
x = torch.nn.functional.log_softmax(x, dim=-1)
alpha, beta, loss = transducer_loss_cuda.forward( x, label, f_len, y_len, batch_offset,
max_f_len, blank_idx, opt, packed_input)
if debug_list == []:
debug_list += [alpha, beta]
ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset)
ctx.blank_idx = blank_idx
ctx.fuse_softmax_backward = fuse_softmax_backward
ctx.opt = opt
ctx.packed_input = packed_input
ctx.max_f_len = max_f_len
return loss
@staticmethod
def backward(ctx, loss_grad):
x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors
x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label,
batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt,
ctx.fuse_softmax_backward, ctx.packed_input)
if ctx.fuse_softmax_backward == False:
x_grad = x.backward(x_grad)
return x_grad, None, None, None, None, None, None, None, None, None, None
class TransducerJointFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch,
opt, fwd_tile_size, dropout_prob, mask_probe):
h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt,
pack_output, relu, dropout, dropout_prob, fwd_tile_size)
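        # The CUDA kernel returns a list: h[0] is the joint output; when ReLU and/or
        # dropout masking is enabled, h[1] is the generated mask, which is saved for
        # backward and optionally exposed through mask_probe.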
masked = relu or dropout
if masked:
ctx.save_for_backward(h[1], f_len, g_len, batch_offset)
if mask_probe is not None:
mask_probe.append(h[1])
else:
ctx.save_for_backward(f_len, g_len, batch_offset)
ctx.pack_output = pack_output
ctx.masked = relu or dropout
ctx.max_f_len = f.size(1)
ctx.max_g_len = g.size(1)
ctx.scale = 1 / (1-dropout_prob) if dropout and dropout_prob != 1 else 1
return h[0]
@staticmethod
def backward(ctx, loss_grad):
if ctx.masked:
mask, f_len, g_len, batch_offset = ctx.saved_tensors
inp = [loss_grad, mask]
else:
f_len, g_len, batch_offset = ctx.saved_tensors
inp = [loss_grad]
f_grad, g_grad = transducer_joint_cuda.backward( inp, f_len, g_len, batch_offset,
ctx.max_f_len, ctx.max_g_len,
ctx.pack_output, ctx.scale)
return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None, None, \
None, None, None